Path: blob/main/crates/polars-arrow/src/array/proptest.rs
8396 views
use std::ops::RangeInclusive;1use std::rc::Rc;23use polars_utils::format_pl_smallstr;4use proptest::prelude::{Just, Strategy};5use proptest::sample::SizeRange;67use super::binview::proptest::binview_array;8use super::{9Array, BinaryArray, BinaryViewArray, BooleanArray, FixedSizeListArray, ListArray, NullArray,10StructArray,11};12use crate::array::binview::proptest::utf8view_array;13use crate::array::boolean::proptest::boolean_array;14use crate::array::primitive::proptest::primitive_array;15use crate::array::{PrimitiveArray, Utf8ViewArray};16use crate::bitmap::bitmask::nth_set_bit_u32;17use crate::datatypes::{ArrowDataType, Field};1819bitflags::bitflags! {20#[derive(Debug, Clone, Copy, PartialEq, Eq)]21pub struct ArrowDataTypeArbitrarySelection: u32 {22const NULL = 1;2324const BOOLEAN = 1 << 1;2526const INT8 = 1 << 2;27const INT16 = 1 << 3;28const INT32 = 1 << 4;29const INT64 = 1 << 5;30const INT128 = 1 << 6;3132const UINT8 = 1 << 7;33const UINT16 = 1 << 8;34const UINT32 = 1 << 9;35const UINT64 = 1 << 10;3637const FLOAT32 = 1 << 11;38const FLOAT64 = 1 << 12;3940const STRVIEW = 1 << 13;41const BINVIEW = 1 << 14;42const BINARY = 1 << 15;4344const LIST = 1 << 16;45const FIXED_SIZE_LIST = 1 << 17;46const STRUCT = 1 << 18;47}48}4950impl ArrowDataTypeArbitrarySelection {51pub fn nested() -> Self {52Self::LIST | Self::FIXED_SIZE_LIST | Self::STRUCT53}54}5556#[derive(Clone)]57pub struct ArrowDataTypeArbitraryOptions {58pub allowed_dtypes: ArrowDataTypeArbitrarySelection,5960pub array_width_range: RangeInclusive<usize>,61pub struct_num_fields_range: RangeInclusive<usize>,6263pub max_nesting_level: usize,64}6566#[derive(Clone)]67pub struct ArrayArbitraryOptions {68pub dtype: ArrowDataTypeArbitraryOptions,69}7071impl Default for ArrowDataTypeArbitraryOptions {72fn default() -> Self {73Self {74allowed_dtypes: ArrowDataTypeArbitrarySelection::all(),75array_width_range: 0..=7,76struct_num_fields_range: 0..=7,77max_nesting_level: 5,78}79}80}8182#[allow(clippy::derivable_impls)]83impl Default for ArrayArbitraryOptions {84fn default() -> Self {85Self {86dtype: Default::default(),87}88}89}9091pub fn arrow_data_type_impl(92options: Rc<ArrowDataTypeArbitraryOptions>,93nesting_level: usize,94) -> impl Strategy<Value = ArrowDataType> {95use ArrowDataTypeArbitrarySelection as S;96let mut allowed_dtypes = options.allowed_dtypes;9798if options.max_nesting_level <= nesting_level {99allowed_dtypes &= !S::nested();100}101102let num_possible_types = allowed_dtypes.bits().count_ones();103assert!(num_possible_types > 0);104105(0..num_possible_types).prop_flat_map(move |i| {106let selection =107S::from_bits_retain(1 << nth_set_bit_u32(options.allowed_dtypes.bits(), i).unwrap());108109match selection {110_ if selection == S::NULL => Just(ArrowDataType::Null).boxed(),111_ if selection == S::BOOLEAN => Just(ArrowDataType::Boolean).boxed(),112_ if selection == S::INT8 => Just(ArrowDataType::Int8).boxed(),113_ if selection == S::INT16 => Just(ArrowDataType::Int16).boxed(),114_ if selection == S::INT32 => Just(ArrowDataType::Int32).boxed(),115_ if selection == S::INT64 => Just(ArrowDataType::Int64).boxed(),116_ if selection == S::INT128 => Just(ArrowDataType::Int128).boxed(),117_ if selection == S::UINT8 => Just(ArrowDataType::UInt8).boxed(),118_ if selection == S::UINT16 => Just(ArrowDataType::UInt16).boxed(),119_ if selection == S::UINT32 => Just(ArrowDataType::UInt32).boxed(),120_ if selection == S::UINT64 => Just(ArrowDataType::UInt64).boxed(),121_ if selection == S::FLOAT32 => Just(ArrowDataType::Float32).boxed(),122_ if selection == S::FLOAT64 => Just(ArrowDataType::Float64).boxed(),123_ if selection == S::STRVIEW => Just(ArrowDataType::Utf8View).boxed(),124_ if selection == S::BINVIEW => Just(ArrowDataType::BinaryView).boxed(),125_ if selection == S::BINARY => Just(ArrowDataType::LargeBinary).boxed(),126_ if selection == S::LIST => arrow_data_type_impl(options.clone(), nesting_level + 1)127.prop_map(|dtype| {128let field = Field::new("item".into(), dtype, true);129ArrowDataType::LargeList(Box::new(field))130})131.boxed(),132_ if selection == S::FIXED_SIZE_LIST => (133arrow_data_type_impl(options.clone(), nesting_level + 1),134options.array_width_range.clone(),135)136.prop_map(|(dtype, width)| {137let field = Field::new("item".into(), dtype, true);138ArrowDataType::FixedSizeList(Box::new(field), width)139})140.boxed(),141_ if selection == S::STRUCT => proptest::collection::vec(142arrow_data_type_impl(options.clone(), nesting_level + 1),143options.struct_num_fields_range.clone(),144)145.prop_map(|dtypes| {146let fields = dtypes147.into_iter()148.enumerate()149.map(|(i, dtype)| Field::new(format_pl_smallstr!("f{}", i + 1), dtype, true))150.collect();151ArrowDataType::Struct(fields)152})153.boxed(),154_ => unreachable!(),155}156})157}158159pub fn arrow_data_type(160options: ArrowDataTypeArbitraryOptions,161) -> impl Strategy<Value = ArrowDataType> {162arrow_data_type_impl(Rc::new(options), 0)163}164165pub fn array_with_dtype(166dtype: ArrowDataType,167size_range: impl Into<SizeRange>,168) -> impl Strategy<Value = Box<dyn Array>> {169let size_range = size_range.into();170match dtype {171ArrowDataType::Null => null_array(size_range).prop_map(NullArray::boxed).boxed(),172ArrowDataType::Boolean => boolean_array(size_range)173.prop_map(BooleanArray::boxed)174.boxed(),175ArrowDataType::Int8 => primitive_array::<i8>(size_range)176.prop_map(PrimitiveArray::boxed)177.boxed(),178ArrowDataType::Int16 => primitive_array::<i16>(size_range)179.prop_map(PrimitiveArray::boxed)180.boxed(),181ArrowDataType::Int32 => primitive_array::<i32>(size_range)182.prop_map(PrimitiveArray::boxed)183.boxed(),184ArrowDataType::Int64 => primitive_array::<i64>(size_range)185.prop_map(PrimitiveArray::boxed)186.boxed(),187ArrowDataType::Int128 => primitive_array::<i128>(size_range)188.prop_map(PrimitiveArray::boxed)189.boxed(),190ArrowDataType::UInt8 => primitive_array::<u8>(size_range)191.prop_map(PrimitiveArray::boxed)192.boxed(),193ArrowDataType::UInt16 => primitive_array::<u16>(size_range)194.prop_map(PrimitiveArray::boxed)195.boxed(),196ArrowDataType::UInt32 => primitive_array::<u32>(size_range)197.prop_map(PrimitiveArray::boxed)198.boxed(),199ArrowDataType::UInt64 => primitive_array::<u64>(size_range)200.prop_map(PrimitiveArray::boxed)201.boxed(),202ArrowDataType::UInt128 => primitive_array::<u128>(size_range)203.prop_map(PrimitiveArray::boxed)204.boxed(),205ArrowDataType::Float32 => primitive_array::<f32>(size_range)206.prop_map(PrimitiveArray::boxed)207.boxed(),208ArrowDataType::Float64 => primitive_array::<f64>(size_range)209.prop_map(PrimitiveArray::boxed)210.boxed(),211ArrowDataType::LargeBinary => super::binary::proptest::binary_array(size_range)212.prop_map(BinaryArray::boxed)213.boxed(),214ArrowDataType::FixedSizeList(field, width) => {215super::fixed_size_list::proptest::fixed_size_list_array_with_dtype(216size_range, field, width,217)218.prop_map(FixedSizeListArray::boxed)219.boxed()220},221ArrowDataType::LargeList(field) => {222super::list::proptest::list_array_with_dtype(size_range, field)223.prop_map(ListArray::<i64>::boxed)224.boxed()225},226ArrowDataType::Struct(fields) => {227super::struct_::proptest::struct_array_with_fields(size_range, fields)228.prop_map(StructArray::boxed)229.boxed()230},231ArrowDataType::BinaryView => binview_array(size_range)232.prop_map(BinaryViewArray::boxed)233.boxed(),234ArrowDataType::Utf8View => utf8view_array(size_range)235.prop_map(Utf8ViewArray::boxed)236.boxed(),237ArrowDataType::Float16238| ArrowDataType::Timestamp(..)239| ArrowDataType::Date32240| ArrowDataType::Date64241| ArrowDataType::Time32(..)242| ArrowDataType::Time64(..)243| ArrowDataType::Duration(..)244| ArrowDataType::Interval(..)245| ArrowDataType::Binary246| ArrowDataType::FixedSizeBinary(_)247| ArrowDataType::Utf8248| ArrowDataType::LargeUtf8249| ArrowDataType::List(..)250| ArrowDataType::Map(_, _)251| ArrowDataType::Dictionary(..)252| ArrowDataType::Decimal(..)253| ArrowDataType::Decimal32(..)254| ArrowDataType::Decimal64(..)255| ArrowDataType::Decimal256(..)256| ArrowDataType::Extension(..)257| ArrowDataType::Unknown258| ArrowDataType::Union(..) => unimplemented!(),259}260}261262pub fn array_with_options(263size_range: impl Into<SizeRange>,264options: ArrayArbitraryOptions,265) -> impl Strategy<Value = Box<dyn Array>> {266let size_range = size_range.into();267arrow_data_type(options.dtype)268.prop_flat_map(move |dtype| array_with_dtype(dtype, size_range.clone()))269}270271pub fn array(size_range: impl Into<SizeRange>) -> impl Strategy<Value = Box<dyn Array>> {272let size_range = size_range.into();273arrow_data_type(Default::default())274.prop_flat_map(move |dtype| array_with_dtype(dtype, size_range.clone()))275}276277pub fn null_array(size_range: impl Into<SizeRange>) -> impl Strategy<Value = NullArray> {278let size_range = size_range.into();279let (min, max) = size_range.start_end_incl();280(min..=max).prop_map(|length| NullArray::new(ArrowDataType::Null, length))281}282283284