Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-core/src/series/proptest.rs
8430 views
1
use std::ops::RangeInclusive;
2
use std::rc::Rc;
3
use std::sync::atomic::{AtomicUsize, Ordering};
4
5
use arrow::bitmap::bitmask::nth_set_bit_u32;
6
#[cfg(feature = "dtype-categorical")]
7
use polars_dtype::categorical::{Categories, FrozenCategories};
8
use proptest::prelude::*;
9
10
use crate::chunked_array::builder::AnonymousListBuilder;
11
#[cfg(feature = "dtype-categorical")]
12
use crate::chunked_array::builder::CategoricalChunkedBuilder;
13
use crate::prelude::{Int32Chunked, Int64Chunked, Int128Chunked, NamedFrom, Series, TimeUnit};
14
#[cfg(feature = "dtype-struct")]
15
use crate::series::StructChunked;
16
use crate::series::from::IntoSeries;
17
#[cfg(feature = "dtype-categorical")]
18
use crate::series::{Categorical8Type, DataType};
19
20
// Global atomic counter used to give every generated Series a unique column name.
// Unique names matter most when several Series strategies are combined into a
// DataFrame strategy.
static COUNTER: AtomicUsize = AtomicUsize::new(0);
23
24
fn next_column_name() -> String {
25
format!("col_{}", COUNTER.fetch_add(1, Ordering::Relaxed))
26
}
27
28
bitflags::bitflags! {
    /// Bit set naming the dtype families that `series_strategy` may generate.
    ///
    /// Every flag occupies exactly one bit, so flags can be combined with `|`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct SeriesArbitrarySelection: u32 {
        // Flags returned by `physical()`.
        const BOOLEAN = 1 << 0;
        const UINT = 1 << 1;
        const INT = 1 << 2;
        const FLOAT = 1 << 3;
        const STRING = 1 << 4;
        const BINARY = 1 << 5;

        // Flags returned by `logical()`.
        const TIME = 1 << 6;
        const DATETIME = 1 << 7;
        const DATE = 1 << 8;
        const DURATION = 1 << 9;
        const DECIMAL = 1 << 10;
        const CATEGORICAL = 1 << 11;
        const ENUM = 1 << 12;

        // Flags returned by `nested()`.
        const LIST = 1 << 13;
        const ARRAY = 1 << 14;
        const STRUCT = 1 << 15;
    }
}
51
52
impl SeriesArbitrarySelection {
53
pub fn physical() -> Self {
54
Self::BOOLEAN | Self::UINT | Self::INT | Self::FLOAT | Self::STRING | Self::BINARY
55
}
56
57
pub fn logical() -> Self {
58
Self::TIME
59
| Self::DATETIME
60
| Self::DATE
61
| Self::DURATION
62
| Self::DECIMAL
63
| Self::CATEGORICAL
64
| Self::ENUM
65
}
66
67
pub fn nested() -> Self {
68
Self::LIST | Self::ARRAY | Self::STRUCT
69
}
70
}
71
72
#[derive(Clone)]
73
pub struct SeriesArbitraryOptions {
74
pub allowed_dtypes: SeriesArbitrarySelection,
75
pub max_nesting_level: usize,
76
pub series_length_range: RangeInclusive<usize>,
77
pub categories_range: RangeInclusive<usize>,
78
pub struct_fields_range: RangeInclusive<usize>,
79
}
80
81
impl Default for SeriesArbitraryOptions {
82
fn default() -> Self {
83
Self {
84
allowed_dtypes: SeriesArbitrarySelection::all(),
85
max_nesting_level: 3,
86
series_length_range: 0..=5,
87
categories_range: 0..=3,
88
struct_fields_range: 0..=3,
89
}
90
}
91
}
92
93
pub fn series_strategy(
94
options: Rc<SeriesArbitraryOptions>,
95
nesting_level: usize,
96
) -> impl Strategy<Value = Series> {
97
use SeriesArbitrarySelection as S;
98
99
let mut allowed_dtypes = options.allowed_dtypes;
100
101
if options.max_nesting_level <= nesting_level {
102
allowed_dtypes &= !S::nested()
103
}
104
105
let num_possible_types = allowed_dtypes.bits().count_ones();
106
assert!(num_possible_types > 0);
107
108
(0..num_possible_types).prop_flat_map(move |i| {
109
let selection =
110
S::from_bits_retain(1 << nth_set_bit_u32(options.allowed_dtypes.bits(), i).unwrap());
111
112
match selection {
113
_ if selection == S::BOOLEAN => {
114
series_boolean_strategy(options.series_length_range.clone()).boxed()
115
},
116
_ if selection == S::UINT => {
117
series_uint_strategy(options.series_length_range.clone()).boxed()
118
},
119
_ if selection == S::INT => {
120
series_int_strategy(options.series_length_range.clone()).boxed()
121
},
122
_ if selection == S::FLOAT => {
123
series_float_strategy(options.series_length_range.clone()).boxed()
124
},
125
_ if selection == S::STRING => {
126
series_string_strategy(options.series_length_range.clone()).boxed()
127
},
128
_ if selection == S::BINARY => {
129
series_binary_strategy(options.series_length_range.clone()).boxed()
130
},
131
#[cfg(feature = "dtype-time")]
132
_ if selection == S::TIME => {
133
series_time_strategy(options.series_length_range.clone()).boxed()
134
},
135
#[cfg(feature = "dtype-datetime")]
136
_ if selection == S::DATETIME => {
137
series_datetime_strategy(options.series_length_range.clone()).boxed()
138
},
139
#[cfg(feature = "dtype-date")]
140
_ if selection == S::DATE => {
141
series_date_strategy(options.series_length_range.clone()).boxed()
142
},
143
#[cfg(feature = "dtype-duration")]
144
_ if selection == S::DURATION => {
145
series_duration_strategy(options.series_length_range.clone()).boxed()
146
},
147
#[cfg(feature = "dtype-decimal")]
148
_ if selection == S::DECIMAL => {
149
series_decimal_strategy(options.series_length_range.clone()).boxed()
150
},
151
#[cfg(feature = "dtype-categorical")]
152
_ if selection == S::CATEGORICAL => series_categorical_strategy(
153
options.series_length_range.clone(),
154
options.categories_range.clone(),
155
)
156
.boxed(),
157
#[cfg(feature = "dtype-categorical")]
158
_ if selection == S::ENUM => series_enum_strategy(
159
options.series_length_range.clone(),
160
options.categories_range.clone(),
161
)
162
.boxed(),
163
_ if selection == S::LIST => series_list_strategy(
164
series_strategy(options.clone(), nesting_level + 1),
165
options.series_length_range.clone(),
166
)
167
.boxed(),
168
#[cfg(feature = "dtype-array")]
169
_ if selection == S::ARRAY => series_array_strategy(
170
series_strategy(options.clone(), nesting_level + 1),
171
options.series_length_range.clone(),
172
)
173
.boxed(),
174
#[cfg(feature = "dtype-struct")]
175
_ if selection == S::STRUCT => series_struct_strategy(
176
series_strategy(options.clone(), nesting_level + 1),
177
options.struct_fields_range.clone(),
178
)
179
.boxed(),
180
_ => unreachable!(),
181
}
182
})
183
}
184
185
fn series_boolean_strategy(
186
series_length_range: RangeInclusive<usize>,
187
) -> impl Strategy<Value = Series> {
188
prop::collection::vec(any::<bool>(), series_length_range)
189
.prop_map(|bools| Series::new(next_column_name().into(), bools))
190
}
191
192
fn series_uint_strategy(
193
series_length_range: RangeInclusive<usize>,
194
) -> impl Strategy<Value = Series> {
195
prop_oneof![
196
prop::collection::vec(any::<u8>(), series_length_range.clone())
197
.prop_map(|uints| Series::new(next_column_name().into(), uints)),
198
prop::collection::vec(any::<u16>(), series_length_range.clone())
199
.prop_map(|uints| Series::new(next_column_name().into(), uints)),
200
prop::collection::vec(any::<u32>(), series_length_range.clone())
201
.prop_map(|uints| Series::new(next_column_name().into(), uints)),
202
prop::collection::vec(any::<u64>(), series_length_range.clone())
203
.prop_map(|uints| Series::new(next_column_name().into(), uints)),
204
prop::collection::vec(any::<u128>(), series_length_range)
205
.prop_map(|uints| Series::new(next_column_name().into(), uints)),
206
]
207
}
208
209
fn series_int_strategy(
210
series_length_range: RangeInclusive<usize>,
211
) -> impl Strategy<Value = Series> {
212
prop_oneof![
213
prop::collection::vec(any::<i8>(), series_length_range.clone())
214
.prop_map(|ints| Series::new(next_column_name().into(), ints)),
215
prop::collection::vec(any::<i16>(), series_length_range.clone())
216
.prop_map(|ints| Series::new(next_column_name().into(), ints)),
217
prop::collection::vec(any::<i32>(), series_length_range.clone())
218
.prop_map(|ints| Series::new(next_column_name().into(), ints)),
219
prop::collection::vec(any::<i64>(), series_length_range.clone())
220
.prop_map(|ints| Series::new(next_column_name().into(), ints)),
221
prop::collection::vec(any::<i128>(), series_length_range)
222
.prop_map(|ints| Series::new(next_column_name().into(), ints)),
223
]
224
}
225
226
fn series_float_strategy(
227
series_length_range: RangeInclusive<usize>,
228
) -> impl Strategy<Value = Series> {
229
prop_oneof![
230
prop::collection::vec(any::<f32>(), series_length_range.clone())
231
.prop_map(|floats| Series::new(next_column_name().into(), floats)),
232
prop::collection::vec(any::<f64>(), series_length_range)
233
.prop_map(|floats| Series::new(next_column_name().into(), floats)),
234
]
235
}
236
237
fn series_string_strategy(
238
series_length_range: RangeInclusive<usize>,
239
) -> impl Strategy<Value = Series> {
240
prop::collection::vec(any::<String>(), series_length_range)
241
.prop_map(|strings| Series::new(next_column_name().into(), strings))
242
}
243
244
fn series_binary_strategy(
245
series_length_range: RangeInclusive<usize>,
246
) -> impl Strategy<Value = Series> {
247
prop::collection::vec(any::<u8>(), series_length_range)
248
.prop_map(|binaries| Series::new(next_column_name().into(), binaries))
249
}
250
251
/// Generates a time Series whose length is drawn from `series_length_range`.
///
/// Values are nanoseconds in the half-open range from 0 to 24 hours.
#[cfg(feature = "dtype-time")]
fn series_time_strategy(
    series_length_range: RangeInclusive<usize>,
) -> impl Strategy<Value = Series> {
    // 24 hours expressed in nanoseconds; used as the exclusive upper bound.
    const NANOSECONDS_PER_DAY: i64 = 86_400_000_000_000i64;

    let nanoseconds = prop::collection::vec(0i64..NANOSECONDS_PER_DAY, series_length_range);
    nanoseconds.prop_map(|nanoseconds| {
        let physical = Int64Chunked::new(next_column_name().into(), &nanoseconds);
        physical.into_time().into_series()
    })
}
265
266
/// Generates a datetime Series (`TimeUnit::Milliseconds`) whose length is drawn from
/// `series_length_range`.
///
/// Values are milliseconds since the UNIX epoch, from 0 (1970-01-01) up to but not
/// including `i64::MAX`.
#[cfg(feature = "dtype-datetime")]
fn series_datetime_strategy(
    series_length_range: RangeInclusive<usize>,
) -> impl Strategy<Value = Series> {
    let timestamps = prop::collection::vec(0i64..i64::MAX, series_length_range);
    timestamps.prop_map(|timestamps| {
        let physical = Int64Chunked::new(next_column_name().into(), &timestamps);
        physical
            .into_datetime(TimeUnit::Milliseconds, None)
            .into_series()
    })
}
280
281
/// Generates a date Series whose length is drawn from `series_length_range`.
///
/// Values are days since 1970-01-01, from 0 up to but not including 50,000
/// (about 137 years, roughly 1970-2107).
#[cfg(feature = "dtype-date")]
fn series_date_strategy(
    series_length_range: RangeInclusive<usize>,
) -> impl Strategy<Value = Series> {
    // Exclusive upper bound of the generated day offsets.
    const DAYS_UPPER_BOUND: i32 = 50_000i32;

    let days = prop::collection::vec(0i32..DAYS_UPPER_BOUND, series_length_range);
    days.prop_map(|days| {
        let physical = Int32Chunked::new(next_column_name().into(), &days);
        physical.into_date().into_series()
    })
}
295
296
/// Generates a duration Series (`TimeUnit::Milliseconds`) whose length is drawn from
/// `series_length_range`.
///
/// Values span `i64::MIN` up to but not including `i64::MAX`, so negative durations
/// (time differences) are generated as well.
#[cfg(feature = "dtype-duration")]
fn series_duration_strategy(
    series_length_range: RangeInclusive<usize>,
) -> impl Strategy<Value = Series> {
    let milliseconds = prop::collection::vec(i64::MIN..i64::MAX, series_length_range);
    milliseconds.prop_map(|milliseconds| {
        let physical = Int64Chunked::new(next_column_name().into(), &milliseconds);
        physical
            .into_duration(TimeUnit::Milliseconds)
            .into_series()
    })
}
310
311
/// Generates a decimal Series with precision 38 and scale 9 whose length is drawn from
/// `series_length_range`.
///
/// The unscaled values are limited to at most 38 decimal digits so that every value
/// fits the declared precision; the unchecked constructor used below does not verify
/// this itself.
#[cfg(feature = "dtype-decimal")]
fn series_decimal_strategy(
    series_length_range: RangeInclusive<usize>,
) -> impl Strategy<Value = Series> {
    // Largest magnitude with 38 digits. The full `i128` range reaches 39 digits and
    // would therefore exceed precision 38.
    const MAX_UNSCALED: i128 = 10i128.pow(38) - 1;

    prop::collection::vec(-MAX_UNSCALED..=MAX_UNSCALED, series_length_range).prop_map(
        |decimals| {
            Int128Chunked::new(next_column_name().into(), &decimals)
                .into_decimal_unchecked(38, 9) // precision = 38, scale = 9 (9 decimal places)
                .into_series()
        },
    )
}
321
322
/// Generates a categorical Series backed by the global categories.
///
/// The number of distinct categories is drawn from `categories_range`; the categories
/// are named `category0`, `category1`, ... and each row is sampled from them. The
/// length is drawn from `series_length_range`, except that a draw of zero categories
/// always yields an empty Series because there is no value a row could take.
#[cfg(feature = "dtype-categorical")]
fn series_categorical_strategy(
    series_length_range: RangeInclusive<usize>,
    categories_range: RangeInclusive<usize>,
) -> impl Strategy<Value = Series> {
    categories_range
        .prop_flat_map(move |n_categories| {
            let possible_categories: Vec<String> =
                (0..n_categories).map(|i| format!("category{i}")).collect();

            // `select` over an empty list cannot yield a value, so make sure it is
            // never asked for one by forcing the length to zero in that case.
            let lengths = if possible_categories.is_empty() {
                0..=0
            } else {
                series_length_range.clone()
            };

            prop::collection::vec(prop::sample::select(possible_categories), lengths)
        })
        .prop_map(|categories| {
            // Using Categorical8Type (u8 backing) which supports up to 256 unique categories
            let mapping = Categories::global().mapping();
            let mut builder = CategoricalChunkedBuilder::<Categorical8Type>::new(
                next_column_name().into(),
                DataType::Categorical(Categories::global(), mapping),
            );

            for category in categories {
                builder.append_str(&category).unwrap();
            }

            builder.finish().into_series()
        })
}
352
353
/// Generates an enum Series over a freshly frozen set of categories.
///
/// The number of categories is drawn from `categories_range`; the categories are named
/// `category0`, `category1`, ... and each row is sampled from them. The length is drawn
/// from `series_length_range`, except that a draw of zero categories always yields an
/// empty Series because there is no value a row could take.
#[cfg(feature = "dtype-categorical")]
fn series_enum_strategy(
    series_length_range: RangeInclusive<usize>,
    categories_range: RangeInclusive<usize>,
) -> impl Strategy<Value = Series> {
    categories_range
        .prop_flat_map(move |n_categories| {
            let possible_categories: Vec<String> =
                (0..n_categories).map(|i| format!("category{i}")).collect();

            // `select` over an empty list cannot yield a value, so make sure it is
            // never asked for one by forcing the length to zero in that case.
            let lengths = if possible_categories.is_empty() {
                0..=0
            } else {
                series_length_range.clone()
            };

            (
                // Pass the full category list along so the enum dtype can be built.
                Just(possible_categories.clone()),
                prop::collection::vec(prop::sample::select(possible_categories), lengths),
            )
        })
        .prop_map(|(possible_categories, sampled_categories)| {
            let frozen_categories =
                FrozenCategories::new(possible_categories.iter().map(|s| s.as_str())).unwrap();
            let mapping = frozen_categories.mapping().clone();

            // Using Categorical8Type (u8 backing) which supports up to 256 unique categories
            let mut builder = CategoricalChunkedBuilder::<Categorical8Type>::new(
                next_column_name().into(),
                DataType::Enum(frozen_categories, mapping),
            );

            for category in sampled_categories {
                builder.append_str(&category).unwrap();
            }

            builder.finish().into_series()
        })
}
389
390
fn series_list_strategy(
391
inner: impl Strategy<Value = Series>,
392
series_length_range: RangeInclusive<usize>,
393
) -> impl Strategy<Value = Series> {
394
inner.prop_flat_map(move |sample_series| {
395
series_length_range.clone().prop_map(move |num_lists| {
396
let mut builder = AnonymousListBuilder::new(
397
next_column_name().into(),
398
num_lists,
399
Some(sample_series.dtype().clone()),
400
);
401
402
for _ in 0..num_lists {
403
builder.append_series(&sample_series).unwrap();
404
}
405
406
builder.finish().into_series()
407
})
408
})
409
}
410
411
/// Generates a fixed-width array Series in which every row is the same Series drawn
/// from `inner`.
///
/// The array width is the length of the sampled Series and the number of rows is
/// drawn from `series_length_range`. The rows are first collected into a list Series,
/// which is then cast to the array dtype.
///
/// # Panics
///
/// Panics if appending a row or the cast to the array dtype fails.
#[cfg(feature = "dtype-array")]
fn series_array_strategy(
    inner: impl Strategy<Value = Series>,
    series_length_range: RangeInclusive<usize>,
) -> impl Strategy<Value = Series> {
    inner.prop_flat_map(move |sample_series| {
        series_length_range.clone().prop_map(move |num_arrays| {
            let width = sample_series.len();

            let mut builder = AnonymousListBuilder::new(
                next_column_name().into(),
                num_arrays,
                Some(sample_series.dtype().clone()),
            );

            for _ in 0..num_arrays {
                builder.append_series(&sample_series).unwrap();
            }

            let list_series = builder.finish().into_series();

            // Spelled with its full path: the file-level `DataType` import is gated on
            // `dtype-categorical`, whereas this function only requires `dtype-array`.
            let array_dtype =
                crate::series::DataType::Array(Box::new(sample_series.dtype().clone()), width);

            list_series.cast(&array_dtype).unwrap()
        })
    })
}
441
442
/// Generates a struct Series whose fields are all copies of one Series drawn from
/// `inner`.
///
/// The number of fields is drawn from `struct_fields_range`; the fields are named
/// `field_0`, `field_1`, ... and the struct has as many rows as the sampled Series.
#[cfg(feature = "dtype-struct")]
fn series_struct_strategy(
    inner: impl Strategy<Value = Series>,
    struct_fields_range: RangeInclusive<usize>,
) -> impl Strategy<Value = Series> {
    inner.prop_flat_map(move |template| {
        let field_counts = struct_fields_range.clone();

        field_counts.prop_map(move |num_fields| {
            let row_count = template.len();

            let mut fields: Vec<Series> = Vec::with_capacity(num_fields);
            for index in 0..num_fields {
                let mut column = template.clone();
                column.rename(format!("field_{}", index).into());
                fields.push(column);
            }

            let structs =
                StructChunked::from_series(next_column_name().into(), row_count, fields.iter())
                    .unwrap();
            structs.into_series()
        })
    })
}
465
466