Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-arrow/src/array/proptest.rs
6939 views
1
use std::ops::RangeInclusive;
2
use std::rc::Rc;
3
4
use polars_utils::format_pl_smallstr;
5
use proptest::prelude::{Just, Strategy};
6
use proptest::sample::SizeRange;
7
8
use super::binview::proptest::binview_array;
9
use super::{
10
Array, BinaryArray, BinaryViewArray, BooleanArray, FixedSizeListArray, ListArray, NullArray,
11
StructArray,
12
};
13
use crate::array::binview::proptest::utf8view_array;
14
use crate::array::boolean::proptest::boolean_array;
15
use crate::array::primitive::proptest::primitive_array;
16
use crate::array::{PrimitiveArray, Utf8ViewArray};
17
use crate::bitmap::bitmask::nth_set_bit_u32;
18
use crate::datatypes::{ArrowDataType, Field};
19
20
bitflags::bitflags! {
    /// Bit set naming the kinds of [`ArrowDataType`] that the data type strategy may produce.
    ///
    /// Each flag corresponds to exactly one data type kind.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct ArrowDataTypeArbitrarySelection: u32 {
        /// Selects `ArrowDataType::Null`.
        const NULL = 1 << 0;

        /// Selects `ArrowDataType::Boolean`.
        const BOOLEAN = 1 << 1;

        /// Selects `ArrowDataType::Int8`.
        const INT8 = 1 << 2;
        /// Selects `ArrowDataType::Int16`.
        const INT16 = 1 << 3;
        /// Selects `ArrowDataType::Int32`.
        const INT32 = 1 << 4;
        /// Selects `ArrowDataType::Int64`.
        const INT64 = 1 << 5;
        /// Selects `ArrowDataType::Int128`.
        const INT128 = 1 << 6;

        /// Selects `ArrowDataType::UInt8`.
        const UINT8 = 1 << 7;
        /// Selects `ArrowDataType::UInt16`.
        const UINT16 = 1 << 8;
        /// Selects `ArrowDataType::UInt32`.
        const UINT32 = 1 << 9;
        /// Selects `ArrowDataType::UInt64`.
        const UINT64 = 1 << 10;

        /// Selects `ArrowDataType::Float32`.
        const FLOAT32 = 1 << 11;
        /// Selects `ArrowDataType::Float64`.
        const FLOAT64 = 1 << 12;

        /// Selects `ArrowDataType::Utf8View`.
        const STRVIEW = 1 << 13;
        /// Selects `ArrowDataType::BinaryView`.
        const BINVIEW = 1 << 14;
        /// Selects `ArrowDataType::LargeBinary`.
        const BINARY = 1 << 15;

        /// Selects `ArrowDataType::LargeList`.
        const LIST = 1 << 16;
        /// Selects `ArrowDataType::FixedSizeList`.
        const FIXED_SIZE_LIST = 1 << 17;
        /// Selects `ArrowDataType::Struct`.
        const STRUCT = 1 << 18;
    }
}
50
51
impl ArrowDataTypeArbitrarySelection {
52
pub fn nested() -> Self {
53
Self::LIST | Self::FIXED_SIZE_LIST | Self::STRUCT
54
}
55
}
56
57
#[derive(Clone)]
58
pub struct ArrowDataTypeArbitraryOptions {
59
pub allowed_dtypes: ArrowDataTypeArbitrarySelection,
60
61
pub array_width_range: RangeInclusive<usize>,
62
pub struct_num_fields_range: RangeInclusive<usize>,
63
64
pub max_nesting_level: usize,
65
}
66
67
#[derive(Clone)]
68
pub struct ArrayArbitraryOptions {
69
pub dtype: ArrowDataTypeArbitraryOptions,
70
}
71
72
impl Default for ArrowDataTypeArbitraryOptions {
73
fn default() -> Self {
74
Self {
75
allowed_dtypes: ArrowDataTypeArbitrarySelection::all(),
76
array_width_range: 0..=7,
77
struct_num_fields_range: 0..=7,
78
max_nesting_level: 5,
79
}
80
}
81
}
82
83
#[allow(clippy::derivable_impls)]
84
impl Default for ArrayArbitraryOptions {
85
fn default() -> Self {
86
Self {
87
dtype: Default::default(),
88
}
89
}
90
}
91
92
pub fn arrow_data_type_impl(
93
options: Rc<ArrowDataTypeArbitraryOptions>,
94
nesting_level: usize,
95
) -> impl Strategy<Value = ArrowDataType> {
96
use ArrowDataTypeArbitrarySelection as S;
97
let mut allowed_dtypes = options.allowed_dtypes;
98
99
if options.max_nesting_level <= nesting_level {
100
allowed_dtypes &= !S::nested();
101
}
102
103
let num_possible_types = allowed_dtypes.bits().count_ones();
104
assert!(num_possible_types > 0);
105
106
(0..num_possible_types).prop_flat_map(move |i| {
107
let selection =
108
S::from_bits_retain(1 << nth_set_bit_u32(options.allowed_dtypes.bits(), i).unwrap());
109
110
match selection {
111
_ if selection == S::NULL => Just(ArrowDataType::Null).boxed(),
112
_ if selection == S::BOOLEAN => Just(ArrowDataType::Boolean).boxed(),
113
_ if selection == S::INT8 => Just(ArrowDataType::Int8).boxed(),
114
_ if selection == S::INT16 => Just(ArrowDataType::Int16).boxed(),
115
_ if selection == S::INT32 => Just(ArrowDataType::Int32).boxed(),
116
_ if selection == S::INT64 => Just(ArrowDataType::Int64).boxed(),
117
_ if selection == S::INT128 => Just(ArrowDataType::Int128).boxed(),
118
_ if selection == S::UINT8 => Just(ArrowDataType::UInt8).boxed(),
119
_ if selection == S::UINT16 => Just(ArrowDataType::UInt16).boxed(),
120
_ if selection == S::UINT32 => Just(ArrowDataType::UInt32).boxed(),
121
_ if selection == S::UINT64 => Just(ArrowDataType::UInt64).boxed(),
122
_ if selection == S::FLOAT32 => Just(ArrowDataType::Float32).boxed(),
123
_ if selection == S::FLOAT64 => Just(ArrowDataType::Float64).boxed(),
124
_ if selection == S::STRVIEW => Just(ArrowDataType::Utf8View).boxed(),
125
_ if selection == S::BINVIEW => Just(ArrowDataType::BinaryView).boxed(),
126
_ if selection == S::BINARY => Just(ArrowDataType::LargeBinary).boxed(),
127
_ if selection == S::LIST => arrow_data_type_impl(options.clone(), nesting_level + 1)
128
.prop_map(|dtype| {
129
let field = Field::new("item".into(), dtype, true);
130
ArrowDataType::LargeList(Box::new(field))
131
})
132
.boxed(),
133
_ if selection == S::FIXED_SIZE_LIST => (
134
arrow_data_type_impl(options.clone(), nesting_level + 1),
135
options.array_width_range.clone(),
136
)
137
.prop_map(|(dtype, width)| {
138
let field = Field::new("item".into(), dtype, true);
139
ArrowDataType::FixedSizeList(Box::new(field), width)
140
})
141
.boxed(),
142
_ if selection == S::STRUCT => proptest::collection::vec(
143
arrow_data_type_impl(options.clone(), nesting_level + 1),
144
options.struct_num_fields_range.clone(),
145
)
146
.prop_map(|dtypes| {
147
let fields = dtypes
148
.into_iter()
149
.enumerate()
150
.map(|(i, dtype)| Field::new(format_pl_smallstr!("f{}", i + 1), dtype, true))
151
.collect();
152
ArrowDataType::Struct(fields)
153
})
154
.boxed(),
155
_ => unreachable!(),
156
}
157
})
158
}
159
160
pub fn arrow_data_type(
161
options: ArrowDataTypeArbitraryOptions,
162
) -> impl Strategy<Value = ArrowDataType> {
163
arrow_data_type_impl(Rc::new(options), 0)
164
}
165
166
pub fn array_with_dtype(
167
dtype: ArrowDataType,
168
size_range: impl Into<SizeRange>,
169
) -> impl Strategy<Value = Box<dyn Array>> {
170
let size_range = size_range.into();
171
match dtype {
172
ArrowDataType::Null => null_array(size_range).prop_map(NullArray::boxed).boxed(),
173
ArrowDataType::Boolean => boolean_array(size_range)
174
.prop_map(BooleanArray::boxed)
175
.boxed(),
176
ArrowDataType::Int8 => primitive_array::<i8>(size_range)
177
.prop_map(PrimitiveArray::boxed)
178
.boxed(),
179
ArrowDataType::Int16 => primitive_array::<i16>(size_range)
180
.prop_map(PrimitiveArray::boxed)
181
.boxed(),
182
ArrowDataType::Int32 => primitive_array::<i32>(size_range)
183
.prop_map(PrimitiveArray::boxed)
184
.boxed(),
185
ArrowDataType::Int64 => primitive_array::<i64>(size_range)
186
.prop_map(PrimitiveArray::boxed)
187
.boxed(),
188
ArrowDataType::Int128 => primitive_array::<i128>(size_range)
189
.prop_map(PrimitiveArray::boxed)
190
.boxed(),
191
ArrowDataType::UInt8 => primitive_array::<u8>(size_range)
192
.prop_map(PrimitiveArray::boxed)
193
.boxed(),
194
ArrowDataType::UInt16 => primitive_array::<u16>(size_range)
195
.prop_map(PrimitiveArray::boxed)
196
.boxed(),
197
ArrowDataType::UInt32 => primitive_array::<u32>(size_range)
198
.prop_map(PrimitiveArray::boxed)
199
.boxed(),
200
ArrowDataType::UInt64 => primitive_array::<u64>(size_range)
201
.prop_map(PrimitiveArray::boxed)
202
.boxed(),
203
ArrowDataType::Float32 => primitive_array::<f32>(size_range)
204
.prop_map(PrimitiveArray::boxed)
205
.boxed(),
206
ArrowDataType::Float64 => primitive_array::<f64>(size_range)
207
.prop_map(PrimitiveArray::boxed)
208
.boxed(),
209
ArrowDataType::LargeBinary => super::binary::proptest::binary_array(size_range)
210
.prop_map(BinaryArray::boxed)
211
.boxed(),
212
ArrowDataType::FixedSizeList(field, width) => {
213
super::fixed_size_list::proptest::fixed_size_list_array_with_dtype(
214
size_range, field, width,
215
)
216
.prop_map(FixedSizeListArray::boxed)
217
.boxed()
218
},
219
ArrowDataType::LargeList(field) => {
220
super::list::proptest::list_array_with_dtype(size_range, field)
221
.prop_map(ListArray::<i64>::boxed)
222
.boxed()
223
},
224
ArrowDataType::Struct(fields) => {
225
super::struct_::proptest::struct_array_with_fields(size_range, fields)
226
.prop_map(StructArray::boxed)
227
.boxed()
228
},
229
ArrowDataType::BinaryView => binview_array(size_range)
230
.prop_map(BinaryViewArray::boxed)
231
.boxed(),
232
ArrowDataType::Utf8View => utf8view_array(size_range)
233
.prop_map(Utf8ViewArray::boxed)
234
.boxed(),
235
ArrowDataType::Float16
236
| ArrowDataType::Timestamp(..)
237
| ArrowDataType::Date32
238
| ArrowDataType::Date64
239
| ArrowDataType::Time32(..)
240
| ArrowDataType::Time64(..)
241
| ArrowDataType::Duration(..)
242
| ArrowDataType::Interval(..)
243
| ArrowDataType::Binary
244
| ArrowDataType::FixedSizeBinary(_)
245
| ArrowDataType::Utf8
246
| ArrowDataType::LargeUtf8
247
| ArrowDataType::List(..)
248
| ArrowDataType::Map(_, _)
249
| ArrowDataType::Dictionary(..)
250
| ArrowDataType::Decimal(..)
251
| ArrowDataType::Decimal32(..)
252
| ArrowDataType::Decimal64(..)
253
| ArrowDataType::Decimal256(..)
254
| ArrowDataType::Extension(..)
255
| ArrowDataType::Unknown
256
| ArrowDataType::Union(..) => unimplemented!(),
257
}
258
}
259
260
pub fn array_with_options(
261
size_range: impl Into<SizeRange>,
262
options: ArrayArbitraryOptions,
263
) -> impl Strategy<Value = Box<dyn Array>> {
264
let size_range = size_range.into();
265
arrow_data_type(options.dtype)
266
.prop_flat_map(move |dtype| array_with_dtype(dtype, size_range.clone()))
267
}
268
269
pub fn array(size_range: impl Into<SizeRange>) -> impl Strategy<Value = Box<dyn Array>> {
270
let size_range = size_range.into();
271
arrow_data_type(Default::default())
272
.prop_flat_map(move |dtype| array_with_dtype(dtype, size_range.clone()))
273
}
274
275
pub fn null_array(size_range: impl Into<SizeRange>) -> impl Strategy<Value = NullArray> {
276
let size_range = size_range.into();
277
let (min, max) = size_range.start_end_incl();
278
(min..=max).prop_map(|length| NullArray::new(ArrowDataType::Null, length))
279
}
280
281