// Source: GitHub repository pola-rs/polars
// Path: blob/main/crates/polars-core/src/series/proptest.rs
1
use std::ops::RangeInclusive;
2
use std::rc::Rc;
3
use std::sync::Arc;
4
use std::sync::atomic::{AtomicUsize, Ordering};
5
6
use arrow::bitmap::bitmask::nth_set_bit_u32;
7
#[cfg(feature = "dtype-categorical")]
8
use polars_dtype::categorical::{CategoricalMapping, Categories, FrozenCategories};
9
use proptest::prelude::*;
10
11
use crate::chunked_array::builder::AnonymousListBuilder;
12
#[cfg(feature = "dtype-categorical")]
13
use crate::chunked_array::builder::CategoricalChunkedBuilder;
14
use crate::prelude::{Int32Chunked, Int64Chunked, Int128Chunked, NamedFrom, Series, TimeUnit};
15
#[cfg(feature = "dtype-struct")]
16
use crate::series::StructChunked;
17
use crate::series::from::IntoSeries;
18
#[cfg(feature = "dtype-categorical")]
19
use crate::series::{Categorical8Type, DataType};
20
21
// Process-wide, thread-safe counter backing the generated column names.
// Handing out a fresh number per Series keeps names unique, which matters when
// several Series strategies are combined into a DataFrame strategy.
static COUNTER: AtomicUsize = AtomicUsize::new(0);

/// Returns the next unused column name, of the form `col_<n>`.
///
/// Each call consumes one value of the global counter, so no two calls
/// return the same name.
fn next_column_name() -> String {
    let id = COUNTER.fetch_add(1, Ordering::Relaxed);
    format!("col_{id}")
}
28
29
bitflags::bitflags! {
    /// Bit set naming the dtype families that `series_strategy` may generate.
    ///
    /// Each flag occupies one bit; flags can be combined with the usual
    /// bitwise operators.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct SeriesArbitrarySelection: u32 {
        // Physical dtypes (bits 0-5).
        const BOOLEAN = 1 << 0;
        const UINT = 1 << 1;
        const INT = 1 << 2;
        const FLOAT = 1 << 3;
        const STRING = 1 << 4;
        const BINARY = 1 << 5;

        // Logical dtypes (bits 6-12).
        const TIME = 1 << 6;
        const DATETIME = 1 << 7;
        const DATE = 1 << 8;
        const DURATION = 1 << 9;
        const DECIMAL = 1 << 10;
        const CATEGORICAL = 1 << 11;
        const ENUM = 1 << 12;

        // Nested dtypes (bits 13-15).
        const LIST = 1 << 13;
        const ARRAY = 1 << 14;
        const STRUCT = 1 << 15;
    }
}
52
53
impl SeriesArbitrarySelection {
54
pub fn physical() -> Self {
55
Self::BOOLEAN | Self::UINT | Self::INT | Self::FLOAT | Self::STRING | Self::BINARY
56
}
57
58
pub fn logical() -> Self {
59
Self::TIME
60
| Self::DATETIME
61
| Self::DATE
62
| Self::DURATION
63
| Self::DECIMAL
64
| Self::CATEGORICAL
65
| Self::ENUM
66
}
67
68
pub fn nested() -> Self {
69
Self::LIST | Self::ARRAY | Self::STRUCT
70
}
71
}
72
73
#[derive(Clone)]
74
pub struct SeriesArbitraryOptions {
75
pub allowed_dtypes: SeriesArbitrarySelection,
76
pub max_nesting_level: usize,
77
pub series_length_range: RangeInclusive<usize>,
78
pub categories_range: RangeInclusive<usize>,
79
pub struct_fields_range: RangeInclusive<usize>,
80
}
81
82
impl Default for SeriesArbitraryOptions {
83
fn default() -> Self {
84
Self {
85
allowed_dtypes: SeriesArbitrarySelection::all(),
86
max_nesting_level: 3,
87
series_length_range: 0..=5,
88
categories_range: 0..=3,
89
struct_fields_range: 0..=3,
90
}
91
}
92
}
93
94
pub fn series_strategy(
95
options: Rc<SeriesArbitraryOptions>,
96
nesting_level: usize,
97
) -> impl Strategy<Value = Series> {
98
use SeriesArbitrarySelection as S;
99
100
let mut allowed_dtypes = options.allowed_dtypes;
101
102
if options.max_nesting_level <= nesting_level {
103
allowed_dtypes &= !S::nested()
104
}
105
106
let num_possible_types = allowed_dtypes.bits().count_ones();
107
assert!(num_possible_types > 0);
108
109
(0..num_possible_types).prop_flat_map(move |i| {
110
let selection =
111
S::from_bits_retain(1 << nth_set_bit_u32(options.allowed_dtypes.bits(), i).unwrap());
112
113
match selection {
114
_ if selection == S::BOOLEAN => {
115
series_boolean_strategy(options.series_length_range.clone()).boxed()
116
},
117
_ if selection == S::UINT => {
118
series_uint_strategy(options.series_length_range.clone()).boxed()
119
},
120
_ if selection == S::INT => {
121
series_int_strategy(options.series_length_range.clone()).boxed()
122
},
123
_ if selection == S::FLOAT => {
124
series_float_strategy(options.series_length_range.clone()).boxed()
125
},
126
_ if selection == S::STRING => {
127
series_string_strategy(options.series_length_range.clone()).boxed()
128
},
129
_ if selection == S::BINARY => {
130
series_binary_strategy(options.series_length_range.clone()).boxed()
131
},
132
#[cfg(feature = "dtype-time")]
133
_ if selection == S::TIME => {
134
series_time_strategy(options.series_length_range.clone()).boxed()
135
},
136
#[cfg(feature = "dtype-datetime")]
137
_ if selection == S::DATETIME => {
138
series_datetime_strategy(options.series_length_range.clone()).boxed()
139
},
140
#[cfg(feature = "dtype-date")]
141
_ if selection == S::DATE => {
142
series_date_strategy(options.series_length_range.clone()).boxed()
143
},
144
#[cfg(feature = "dtype-duration")]
145
_ if selection == S::DURATION => {
146
series_duration_strategy(options.series_length_range.clone()).boxed()
147
},
148
#[cfg(feature = "dtype-decimal")]
149
_ if selection == S::DECIMAL => {
150
series_decimal_strategy(options.series_length_range.clone()).boxed()
151
},
152
#[cfg(feature = "dtype-categorical")]
153
_ if selection == S::CATEGORICAL => series_categorical_strategy(
154
options.series_length_range.clone(),
155
options.categories_range.clone(),
156
)
157
.boxed(),
158
#[cfg(feature = "dtype-categorical")]
159
_ if selection == S::ENUM => series_enum_strategy(
160
options.series_length_range.clone(),
161
options.categories_range.clone(),
162
)
163
.boxed(),
164
_ if selection == S::LIST => series_list_strategy(
165
series_strategy(options.clone(), nesting_level + 1),
166
options.series_length_range.clone(),
167
)
168
.boxed(),
169
#[cfg(feature = "dtype-array")]
170
_ if selection == S::ARRAY => series_array_strategy(
171
series_strategy(options.clone(), nesting_level + 1),
172
options.series_length_range.clone(),
173
)
174
.boxed(),
175
#[cfg(feature = "dtype-struct")]
176
_ if selection == S::STRUCT => series_struct_strategy(
177
series_strategy(options.clone(), nesting_level + 1),
178
options.struct_fields_range.clone(),
179
)
180
.boxed(),
181
_ => unreachable!(),
182
}
183
})
184
}
185
186
fn series_boolean_strategy(
187
series_length_range: RangeInclusive<usize>,
188
) -> impl Strategy<Value = Series> {
189
prop::collection::vec(any::<bool>(), series_length_range)
190
.prop_map(|bools| Series::new(next_column_name().into(), bools))
191
}
192
193
fn series_uint_strategy(
194
series_length_range: RangeInclusive<usize>,
195
) -> impl Strategy<Value = Series> {
196
prop_oneof![
197
prop::collection::vec(any::<u8>(), series_length_range.clone())
198
.prop_map(|uints| Series::new(next_column_name().into(), uints)),
199
prop::collection::vec(any::<u16>(), series_length_range.clone())
200
.prop_map(|uints| Series::new(next_column_name().into(), uints)),
201
prop::collection::vec(any::<u32>(), series_length_range.clone())
202
.prop_map(|uints| Series::new(next_column_name().into(), uints)),
203
prop::collection::vec(any::<u64>(), series_length_range.clone())
204
.prop_map(|uints| Series::new(next_column_name().into(), uints)),
205
prop::collection::vec(any::<u128>(), series_length_range)
206
.prop_map(|uints| Series::new(next_column_name().into(), uints)),
207
]
208
}
209
210
fn series_int_strategy(
211
series_length_range: RangeInclusive<usize>,
212
) -> impl Strategy<Value = Series> {
213
prop_oneof![
214
prop::collection::vec(any::<i8>(), series_length_range.clone())
215
.prop_map(|ints| Series::new(next_column_name().into(), ints)),
216
prop::collection::vec(any::<i16>(), series_length_range.clone())
217
.prop_map(|ints| Series::new(next_column_name().into(), ints)),
218
prop::collection::vec(any::<i32>(), series_length_range.clone())
219
.prop_map(|ints| Series::new(next_column_name().into(), ints)),
220
prop::collection::vec(any::<i64>(), series_length_range.clone())
221
.prop_map(|ints| Series::new(next_column_name().into(), ints)),
222
prop::collection::vec(any::<i128>(), series_length_range)
223
.prop_map(|ints| Series::new(next_column_name().into(), ints)),
224
]
225
}
226
227
fn series_float_strategy(
228
series_length_range: RangeInclusive<usize>,
229
) -> impl Strategy<Value = Series> {
230
prop_oneof![
231
prop::collection::vec(any::<f32>(), series_length_range.clone())
232
.prop_map(|floats| Series::new(next_column_name().into(), floats)),
233
prop::collection::vec(any::<f64>(), series_length_range)
234
.prop_map(|floats| Series::new(next_column_name().into(), floats)),
235
]
236
}
237
238
fn series_string_strategy(
239
series_length_range: RangeInclusive<usize>,
240
) -> impl Strategy<Value = Series> {
241
prop::collection::vec(any::<String>(), series_length_range)
242
.prop_map(|strings| Series::new(next_column_name().into(), strings))
243
}
244
245
fn series_binary_strategy(
246
series_length_range: RangeInclusive<usize>,
247
) -> impl Strategy<Value = Series> {
248
prop::collection::vec(any::<u8>(), series_length_range)
249
.prop_map(|binaries| Series::new(next_column_name().into(), binaries))
250
}
251
252
/// Strategy for a time series whose length lies in `series_length_range`.
///
/// Values are drawn from `0` up to, but excluding, 24 hours expressed in
/// nanoseconds. The series gets a fresh name from `next_column_name`.
#[cfg(feature = "dtype-time")]
fn series_time_strategy(
    series_length_range: RangeInclusive<usize>,
) -> impl Strategy<Value = Series> {
    // 24 hours in nanoseconds.
    const NANOS_PER_DAY: i64 = 86_400_000_000_000;

    let nanos_since_midnight = 0i64..NANOS_PER_DAY;

    prop::collection::vec(nanos_since_midnight, series_length_range).prop_map(|values| {
        let physical = Int64Chunked::new(next_column_name().into(), &values);
        physical.into_time().into_series()
    })
}
266
267
/// Strategy for a millisecond-precision datetime series (no time zone) whose
/// length lies in `series_length_range`.
///
/// Values are drawn from `0` (1970-01-01) up to, but excluding, `i64::MAX`
/// milliseconds since the UNIX epoch. The series gets a fresh name from
/// `next_column_name`.
#[cfg(feature = "dtype-datetime")]
fn series_datetime_strategy(
    series_length_range: RangeInclusive<usize>,
) -> impl Strategy<Value = Series> {
    let millis_since_epoch = 0i64..i64::MAX;

    prop::collection::vec(millis_since_epoch, series_length_range).prop_map(|values| {
        let physical = Int64Chunked::new(next_column_name().into(), &values);
        physical
            .into_datetime(TimeUnit::Milliseconds, None)
            .into_series()
    })
}
281
282
/// Strategy for a date series whose length lies in `series_length_range`.
///
/// Values are day counts drawn from `0` (1970-01-01) up to, but excluding,
/// 50,000 days (roughly 137 years). The series gets a fresh name from
/// `next_column_name`.
#[cfg(feature = "dtype-date")]
fn series_date_strategy(
    series_length_range: RangeInclusive<usize>,
) -> impl Strategy<Value = Series> {
    // Exclusive upper bound on the generated day count.
    const DAYS_UPPER_BOUND: i32 = 50_000;

    let days_since_epoch = 0i32..DAYS_UPPER_BOUND;

    prop::collection::vec(days_since_epoch, series_length_range).prop_map(|values| {
        let physical = Int32Chunked::new(next_column_name().into(), &values);
        physical.into_date().into_series()
    })
}
296
297
/// Strategy for a millisecond-precision duration series whose length lies in
/// `series_length_range`.
///
/// Values are drawn from `i64::MIN` up to, but excluding, `i64::MAX`;
/// negative values are allowed since durations can be time differences.
/// The series gets a fresh name from `next_column_name`.
#[cfg(feature = "dtype-duration")]
fn series_duration_strategy(
    series_length_range: RangeInclusive<usize>,
) -> impl Strategy<Value = Series> {
    let millis = i64::MIN..i64::MAX;

    prop::collection::vec(millis, series_length_range).prop_map(|values| {
        let physical = Int64Chunked::new(next_column_name().into(), &values);
        physical
            .into_duration(TimeUnit::Milliseconds)
            .into_series()
    })
}
311
312
/// Strategy for a decimal series (precision 38, scale 9) whose length lies in
/// `series_length_range`.
///
/// The unscaled values are drawn from `-(10^38 - 1)..=(10^38 - 1)`, i.e. every
/// integer with at most 38 decimal digits. The series gets a fresh name from
/// `next_column_name`.
#[cfg(feature = "dtype-decimal")]
fn series_decimal_strategy(
    series_length_range: RangeInclusive<usize>,
) -> impl Strategy<Value = Series> {
    // Largest magnitude that still fits in 38 decimal digits. `i128` itself
    // reaches about 1.7e38, so drawing from the full `i128` range would yield
    // 39-digit values that do not fit the precision declared below, and the
    // unchecked constructor would not catch that.
    const MAX_UNSCALED: i128 = 10i128.pow(38) - 1;

    prop::collection::vec(-MAX_UNSCALED..=MAX_UNSCALED, series_length_range).prop_map(
        |decimals| {
            Int128Chunked::new(next_column_name().into(), &decimals)
                .into_decimal_unchecked(38, 9) // precision = 38 (max for i128), scale = 9
                .into_series()
        },
    )
}
322
323
/// Strategy for a categorical series whose length lies in
/// `series_length_range`.
///
/// A number of categories is drawn from `categories_range` and the values are
/// sampled from the strings `category0`, `category1`, ... A drawn count of `0`
/// is treated as `1`, because values cannot be sampled from an empty set of
/// categories. The series gets a fresh name from `next_column_name`.
#[cfg(feature = "dtype-categorical")]
fn series_categorical_strategy(
    series_length_range: RangeInclusive<usize>,
    categories_range: RangeInclusive<usize>,
) -> impl Strategy<Value = Series> {
    categories_range
        .prop_flat_map(move |n_categories| {
            // `prop::sample::select` panics when it has to draw from an empty
            // list, which would happen for every non-empty series when
            // `n_categories` is 0. Always offer at least one category.
            let possible_categories: Vec<String> = (0..n_categories.max(1))
                .map(|i| format!("category{i}"))
                .collect();

            prop::collection::vec(
                prop::sample::select(possible_categories),
                series_length_range.clone(),
            )
        })
        .prop_map(|categories| {
            // Using Categorical8Type (u8 backing) which supports up to 256 unique categories
            let mut builder = CategoricalChunkedBuilder::<Categorical8Type>::new(
                next_column_name().into(),
                DataType::Categorical(Categories::global(), Arc::new(CategoricalMapping::new(256))),
            );

            for category in categories {
                builder.append_str(&category).unwrap();
            }

            builder.finish().into_series()
        })
}
352
353
/// Strategy for an enum series whose length lies in `series_length_range`.
///
/// A number of categories is drawn from `categories_range`; the enum is
/// declared over the strings `category0`, `category1`, ... and the values are
/// sampled from that same list. A drawn count of `0` is treated as `1`,
/// because values cannot be sampled from an empty set of categories. The
/// series gets a fresh name from `next_column_name`.
#[cfg(feature = "dtype-categorical")]
fn series_enum_strategy(
    series_length_range: RangeInclusive<usize>,
    categories_range: RangeInclusive<usize>,
) -> impl Strategy<Value = Series> {
    categories_range
        .prop_flat_map(move |n_categories| {
            // `prop::sample::select` panics when it has to draw from an empty
            // list, which would happen for every non-empty series when
            // `n_categories` is 0. Always offer at least one category.
            let possible_categories: Vec<String> = (0..n_categories.max(1))
                .map(|i| format!("category{i}"))
                .collect();

            (
                // Carried along so the enum dtype can be declared over the
                // full category list, not just the sampled values.
                Just(possible_categories.clone()),
                prop::collection::vec(
                    prop::sample::select(possible_categories),
                    series_length_range.clone(),
                ),
            )
        })
        .prop_map(|(possible_categories, sampled_categories)| {
            let frozen_categories =
                FrozenCategories::new(possible_categories.iter().map(|s| s.as_str())).unwrap();
            let mapping = frozen_categories.mapping().clone();

            // Using Categorical8Type (u8 backing) which supports up to 256 unique categories
            let mut builder = CategoricalChunkedBuilder::<Categorical8Type>::new(
                next_column_name().into(),
                DataType::Enum(frozen_categories, mapping),
            );

            for category in sampled_categories {
                builder.append_str(&category).unwrap();
            }

            builder.finish().into_series()
        })
}
389
390
fn series_list_strategy(
391
inner: impl Strategy<Value = Series>,
392
series_length_range: RangeInclusive<usize>,
393
) -> impl Strategy<Value = Series> {
394
inner.prop_flat_map(move |sample_series| {
395
series_length_range.clone().prop_map(move |num_lists| {
396
let mut builder = AnonymousListBuilder::new(
397
next_column_name().into(),
398
num_lists,
399
Some(sample_series.dtype().clone()),
400
);
401
402
for _ in 0..num_lists {
403
builder.append_series(&sample_series).unwrap();
404
}
405
406
builder.finish().into_series()
407
})
408
})
409
}
410
411
/// Strategy for a fixed-width array series built from a series drawn from
/// `inner`.
///
/// A row count is drawn from `series_length_range`. The drawn inner series is
/// appended that many times to a list builder, and the resulting list series
/// is cast to an array dtype whose width is the length of the inner series.
/// The series gets a fresh name from `next_column_name`.
#[cfg(feature = "dtype-array")]
fn series_array_strategy(
    inner: impl Strategy<Value = Series>,
    series_length_range: RangeInclusive<usize>,
) -> impl Strategy<Value = Series> {
    inner.prop_flat_map(move |sample_series| {
        series_length_range.clone().prop_map(move |num_arrays| {
            // Every row is a copy of `sample_series`, so its length is the
            // fixed width of the array dtype.
            let width = sample_series.len();

            let mut builder = AnonymousListBuilder::new(
                next_column_name().into(),
                num_arrays,
                Some(sample_series.dtype().clone()),
            );

            for _ in 0..num_arrays {
                builder.append_series(&sample_series).unwrap();
            }

            let list_series = builder.finish().into_series();

            // `DataType` is spelled with its full path because the file-level
            // import of it is only active under the `dtype-categorical`
            // feature, while this function is gated on `dtype-array`.
            list_series
                .cast(&crate::series::DataType::Array(
                    Box::new(sample_series.dtype().clone()),
                    width,
                ))
                .unwrap()
        })
    })
}
441
442
/// Strategy for a struct series built from a series drawn from `inner`.
///
/// A field count is drawn from `struct_fields_range`; each field is a copy of
/// the drawn inner series renamed to `field_0`, `field_1`, ... The struct has
/// as many rows as the inner series and gets a fresh name from
/// `next_column_name`.
#[cfg(feature = "dtype-struct")]
fn series_struct_strategy(
    inner: impl Strategy<Value = Series>,
    struct_fields_range: RangeInclusive<usize>,
) -> impl Strategy<Value = Series> {
    inner.prop_flat_map(move |template| {
        struct_fields_range.clone().prop_map(move |field_count| {
            let mut fields = Vec::with_capacity(field_count);
            for idx in 0..field_count {
                let mut column = template.clone();
                column.rename(format!("field_{idx}").into());
                fields.push(column);
            }

            let row_count = template.len();
            StructChunked::from_series(next_column_name().into(), row_count, fields.iter())
                .unwrap()
                .into_series()
        })
    })
}
465
466