Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-core/src/chunked_array/random.rs
6940 views
1
use num_traits::{Float, NumCast};
2
use polars_error::to_compute_err;
3
use rand::distr::Bernoulli;
4
use rand::prelude::*;
5
use rand::seq::index::IndexVec;
6
use rand_distr::{Normal, StandardNormal, StandardUniform, Uniform};
7
8
use crate::prelude::DataType::Float64;
9
use crate::prelude::*;
10
use crate::random::get_global_random_u64;
11
use crate::utils::NoNull;
12
13
fn create_rand_index_with_replacement(n: usize, len: usize, seed: Option<u64>) -> IdxCa {
14
if len == 0 {
15
return IdxCa::new_vec(PlSmallStr::EMPTY, vec![]);
16
}
17
let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
18
let dist = Uniform::new(0, len as IdxSize).unwrap();
19
(0..n as IdxSize)
20
.map(move |_| dist.sample(&mut rng))
21
.collect_trusted::<NoNull<IdxCa>>()
22
.into_inner()
23
}
24
25
fn create_rand_index_no_replacement(
26
n: usize,
27
len: usize,
28
seed: Option<u64>,
29
shuffle: bool,
30
) -> IdxCa {
31
let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
32
let mut buf: Vec<IdxSize>;
33
if n == len {
34
buf = (0..len as IdxSize).collect();
35
if shuffle {
36
buf.shuffle(&mut rng)
37
}
38
} else {
39
// TODO: avoid extra potential copy by vendoring rand::seq::index::sample,
40
// or genericize take over slices over any unsigned type. The optimizer
41
// should get rid of the extra copy already if IdxSize matches the IndexVec
42
// size returned.
43
buf = match rand::seq::index::sample(&mut rng, len, n) {
44
IndexVec::U32(v) => v.into_iter().map(|x| x as IdxSize).collect(),
45
#[cfg(target_pointer_width = "64")]
46
IndexVec::U64(v) => v.into_iter().map(|x| x as IdxSize).collect(),
47
};
48
}
49
IdxCa::new_vec(PlSmallStr::EMPTY, buf)
50
}
51
52
impl<T> ChunkedArray<T>
53
where
54
T: PolarsNumericType,
55
StandardUniform: Distribution<T::Native>,
56
{
57
pub fn init_rand(size: usize, null_density: f32, seed: Option<u64>) -> Self {
58
let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
59
(0..size)
60
.map(|_| {
61
if rng.random::<f32>() < null_density {
62
None
63
} else {
64
Some(rng.random())
65
}
66
})
67
.collect()
68
}
69
}
70
71
/// Check that a sample of size `n` can be drawn from a population of `len`.
///
/// Sampling with replacement is always accepted; without replacement the
/// sample may not be larger than the population.
///
/// # Errors
///
/// Returns a `ShapeMismatch` error when `with_replacement` is `false` and
/// `n > len`.
fn ensure_shape(n: usize, len: usize, with_replacement: bool) -> PolarsResult<()> {
    let fits_population = n <= len;
    polars_ensure!(
        with_replacement || fits_population,
        ShapeMismatch:
        "cannot take a larger sample than the total population when `with_replacement=false`"
    );
    Ok(())
}
79
80
impl Series {
81
pub fn sample_n(
82
&self,
83
n: usize,
84
with_replacement: bool,
85
shuffle: bool,
86
seed: Option<u64>,
87
) -> PolarsResult<Self> {
88
ensure_shape(n, self.len(), with_replacement)?;
89
if n == 0 {
90
return Ok(self.clear());
91
}
92
let len = self.len();
93
94
match with_replacement {
95
true => {
96
let idx = create_rand_index_with_replacement(n, len, seed);
97
debug_assert_eq!(len, self.len());
98
// SAFETY: we know that we never go out of bounds.
99
unsafe { Ok(self.take_unchecked(&idx)) }
100
},
101
false => {
102
let idx = create_rand_index_no_replacement(n, len, seed, shuffle);
103
debug_assert_eq!(len, self.len());
104
// SAFETY: we know that we never go out of bounds.
105
unsafe { Ok(self.take_unchecked(&idx)) }
106
},
107
}
108
}
109
110
/// Sample a fraction between 0.0-1.0 of this [`ChunkedArray`].
111
pub fn sample_frac(
112
&self,
113
frac: f64,
114
with_replacement: bool,
115
shuffle: bool,
116
seed: Option<u64>,
117
) -> PolarsResult<Self> {
118
let n = (self.len() as f64 * frac) as usize;
119
self.sample_n(n, with_replacement, shuffle, seed)
120
}
121
122
pub fn shuffle(&self, seed: Option<u64>) -> Self {
123
let len = self.len();
124
let n = len;
125
let idx = create_rand_index_no_replacement(n, len, seed, true);
126
debug_assert_eq!(len, self.len());
127
// SAFETY: we know that we never go out of bounds.
128
unsafe { self.take_unchecked(&idx) }
129
}
130
}
131
132
impl<T> ChunkedArray<T>
133
where
134
T: PolarsDataType,
135
ChunkedArray<T>: ChunkTake<IdxCa>,
136
{
137
/// Sample n datapoints from this [`ChunkedArray`].
138
pub fn sample_n(
139
&self,
140
n: usize,
141
with_replacement: bool,
142
shuffle: bool,
143
seed: Option<u64>,
144
) -> PolarsResult<Self> {
145
ensure_shape(n, self.len(), with_replacement)?;
146
let len = self.len();
147
148
match with_replacement {
149
true => {
150
let idx = create_rand_index_with_replacement(n, len, seed);
151
debug_assert_eq!(len, self.len());
152
// SAFETY: we know that we never go out of bounds.
153
unsafe { Ok(self.take_unchecked(&idx)) }
154
},
155
false => {
156
let idx = create_rand_index_no_replacement(n, len, seed, shuffle);
157
debug_assert_eq!(len, self.len());
158
// SAFETY: we know that we never go out of bounds.
159
unsafe { Ok(self.take_unchecked(&idx)) }
160
},
161
}
162
}
163
164
/// Sample a fraction between 0.0-1.0 of this [`ChunkedArray`].
165
pub fn sample_frac(
166
&self,
167
frac: f64,
168
with_replacement: bool,
169
shuffle: bool,
170
seed: Option<u64>,
171
) -> PolarsResult<Self> {
172
let n = (self.len() as f64 * frac) as usize;
173
self.sample_n(n, with_replacement, shuffle, seed)
174
}
175
}
176
177
impl DataFrame {
178
/// Sample n datapoints from this [`DataFrame`].
179
pub fn sample_n(
180
&self,
181
n: &Series,
182
with_replacement: bool,
183
shuffle: bool,
184
seed: Option<u64>,
185
) -> PolarsResult<Self> {
186
polars_ensure!(
187
n.len() == 1,
188
ComputeError: "Sample size must be a single value."
189
);
190
191
let n = n.cast(&IDX_DTYPE)?;
192
let n = n.idx()?;
193
194
match n.get(0) {
195
Some(n) => self.sample_n_literal(n as usize, with_replacement, shuffle, seed),
196
None => Ok(self.clear()),
197
}
198
}
199
200
pub fn sample_n_literal(
201
&self,
202
n: usize,
203
with_replacement: bool,
204
shuffle: bool,
205
seed: Option<u64>,
206
) -> PolarsResult<Self> {
207
ensure_shape(n, self.height(), with_replacement)?;
208
// All columns should used the same indices. So we first create the indices.
209
let idx = match with_replacement {
210
true => create_rand_index_with_replacement(n, self.height(), seed),
211
false => create_rand_index_no_replacement(n, self.height(), seed, shuffle),
212
};
213
// SAFETY: the indices are within bounds.
214
Ok(unsafe { self.take_unchecked(&idx) })
215
}
216
217
/// Sample a fraction between 0.0-1.0 of this [`DataFrame`].
218
pub fn sample_frac(
219
&self,
220
frac: &Series,
221
with_replacement: bool,
222
shuffle: bool,
223
seed: Option<u64>,
224
) -> PolarsResult<Self> {
225
polars_ensure!(
226
frac.len() == 1,
227
ComputeError: "Sample fraction must be a single value."
228
);
229
230
let frac = frac.cast(&Float64)?;
231
let frac = frac.f64()?;
232
233
match frac.get(0) {
234
Some(frac) => {
235
let n = (self.height() as f64 * frac) as usize;
236
self.sample_n_literal(n, with_replacement, shuffle, seed)
237
},
238
None => Ok(self.clear()),
239
}
240
}
241
}
242
243
impl<T> ChunkedArray<T>
244
where
245
T: PolarsNumericType,
246
T::Native: Float,
247
{
248
/// Create [`ChunkedArray`] with samples from a Normal distribution.
249
pub fn rand_normal(
250
name: PlSmallStr,
251
length: usize,
252
mean: f64,
253
std_dev: f64,
254
) -> PolarsResult<Self> {
255
let normal = Normal::new(mean, std_dev).map_err(to_compute_err)?;
256
let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
257
let mut rng = rand::rng();
258
for _ in 0..length {
259
let smpl = normal.sample(&mut rng);
260
let smpl = NumCast::from(smpl).unwrap();
261
builder.append_value(smpl)
262
}
263
Ok(builder.finish())
264
}
265
266
/// Create [`ChunkedArray`] with samples from a Standard Normal distribution.
267
pub fn rand_standard_normal(name: PlSmallStr, length: usize) -> Self {
268
let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
269
let mut rng = rand::rng();
270
for _ in 0..length {
271
let smpl: f64 = rng.sample(StandardNormal);
272
let smpl = NumCast::from(smpl).unwrap();
273
builder.append_value(smpl)
274
}
275
builder.finish()
276
}
277
278
/// Create [`ChunkedArray`] with samples from a Uniform distribution.
279
pub fn rand_uniform(name: PlSmallStr, length: usize, low: f64, high: f64) -> Self {
280
let uniform = Uniform::new(low, high).unwrap();
281
let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
282
let mut rng = rand::rng();
283
for _ in 0..length {
284
let smpl = uniform.sample(&mut rng);
285
let smpl = NumCast::from(smpl).unwrap();
286
builder.append_value(smpl)
287
}
288
builder.finish()
289
}
290
}
291
292
impl BooleanChunked {
293
/// Create [`ChunkedArray`] with samples from a Bernoulli distribution.
294
pub fn rand_bernoulli(name: PlSmallStr, length: usize, p: f64) -> PolarsResult<Self> {
295
let dist = Bernoulli::new(p).map_err(to_compute_err)?;
296
let mut rng = rand::rng();
297
let mut builder = BooleanChunkedBuilder::new(name, length);
298
for _ in 0..length {
299
let smpl = dist.sample(&mut rng);
300
builder.append_value(smpl)
301
}
302
Ok(builder.finish())
303
}
304
}
305
306
#[cfg(test)]
mod test {
    use super::*;

    /// Smoke test for `DataFrame::sample_n` / `DataFrame::sample_frac`: checks
    /// only whether each call succeeds or fails, not which rows are returned.
    #[test]
    fn test_sample() {
        let df = df![
            "foo" => &[1, 2, 3, 4, 5]
        ]
        .unwrap();

        // Inputs shared by the assertions below.
        let three = Series::new(PlSmallStr::from_static("s"), &[3]);
        let two_fifths = Series::new(PlSmallStr::from_static("frac"), &[0.4]);
        let double = Series::new(PlSmallStr::from_static("frac"), &[2.0]);

        // Default samples are random and don't require seeds.
        assert!(df.sample_n(&three, false, false, None).is_ok());
        assert!(df.sample_frac(&two_fifths, false, false, None).is_ok());

        // With seeding.
        assert!(df.sample_n(&three, false, false, Some(0)).is_ok());
        assert!(df.sample_frac(&two_fifths, false, false, Some(0)).is_ok());

        // Without replacement can not sample more than 100%.
        assert!(df.sample_frac(&double, false, false, Some(0)).is_err());

        assert!(df.sample_n(&three, true, false, Some(0)).is_ok());
        assert!(df.sample_frac(&two_fifths, true, false, Some(0)).is_ok());

        // With replacement can sample more than 100%.
        assert!(df.sample_frac(&double, true, false, Some(0)).is_ok());
    }
}
395
396