Path: blob/main/crates/polars-arrow/src/array/proptest.rs
6939 views
use std::ops::RangeInclusive;1use std::rc::Rc;23use polars_utils::format_pl_smallstr;4use proptest::prelude::{Just, Strategy};5use proptest::sample::SizeRange;67use super::binview::proptest::binview_array;8use super::{9Array, BinaryArray, BinaryViewArray, BooleanArray, FixedSizeListArray, ListArray, NullArray,10StructArray,11};12use crate::array::binview::proptest::utf8view_array;13use crate::array::boolean::proptest::boolean_array;14use crate::array::primitive::proptest::primitive_array;15use crate::array::{PrimitiveArray, Utf8ViewArray};16use crate::bitmap::bitmask::nth_set_bit_u32;17use crate::datatypes::{ArrowDataType, Field};1819bitflags::bitflags! {20#[derive(Debug, Clone, Copy, PartialEq, Eq)]21pub struct ArrowDataTypeArbitrarySelection: u32 {22const NULL = 1;2324const BOOLEAN = 1 << 1;2526const INT8 = 1 << 2;27const INT16 = 1 << 3;28const INT32 = 1 << 4;29const INT64 = 1 << 5;30const INT128 = 1 << 6;3132const UINT8 = 1 << 7;33const UINT16 = 1 << 8;34const UINT32 = 1 << 9;35const UINT64 = 1 << 10;3637const FLOAT32 = 1 << 11;38const FLOAT64 = 1 << 12;3940const STRVIEW = 1 << 13;41const BINVIEW = 1 << 14;42const BINARY = 1 << 15;4344const LIST = 1 << 16;45const FIXED_SIZE_LIST = 1 << 17;46const STRUCT = 1 << 18;47}48}4950impl ArrowDataTypeArbitrarySelection {51pub fn nested() -> Self {52Self::LIST | Self::FIXED_SIZE_LIST | Self::STRUCT53}54}5556#[derive(Clone)]57pub struct ArrowDataTypeArbitraryOptions {58pub allowed_dtypes: ArrowDataTypeArbitrarySelection,5960pub array_width_range: RangeInclusive<usize>,61pub struct_num_fields_range: RangeInclusive<usize>,6263pub max_nesting_level: usize,64}6566#[derive(Clone)]67pub struct ArrayArbitraryOptions {68pub dtype: ArrowDataTypeArbitraryOptions,69}7071impl Default for ArrowDataTypeArbitraryOptions {72fn default() -> Self {73Self {74allowed_dtypes: ArrowDataTypeArbitrarySelection::all(),75array_width_range: 0..=7,76struct_num_fields_range: 0..=7,77max_nesting_level: 5,78}79}80}8182#[allow(clippy::derivable_impls)]83impl Default for ArrayArbitraryOptions {84fn default() -> Self {85Self {86dtype: Default::default(),87}88}89}9091pub fn arrow_data_type_impl(92options: Rc<ArrowDataTypeArbitraryOptions>,93nesting_level: usize,94) -> impl Strategy<Value = ArrowDataType> {95use ArrowDataTypeArbitrarySelection as S;96let mut allowed_dtypes = options.allowed_dtypes;9798if options.max_nesting_level <= nesting_level {99allowed_dtypes &= !S::nested();100}101102let num_possible_types = allowed_dtypes.bits().count_ones();103assert!(num_possible_types > 0);104105(0..num_possible_types).prop_flat_map(move |i| {106let selection =107S::from_bits_retain(1 << nth_set_bit_u32(options.allowed_dtypes.bits(), i).unwrap());108109match selection {110_ if selection == S::NULL => Just(ArrowDataType::Null).boxed(),111_ if selection == S::BOOLEAN => Just(ArrowDataType::Boolean).boxed(),112_ if selection == S::INT8 => Just(ArrowDataType::Int8).boxed(),113_ if selection == S::INT16 => Just(ArrowDataType::Int16).boxed(),114_ if selection == S::INT32 => Just(ArrowDataType::Int32).boxed(),115_ if selection == S::INT64 => Just(ArrowDataType::Int64).boxed(),116_ if selection == S::INT128 => Just(ArrowDataType::Int128).boxed(),117_ if selection == S::UINT8 => Just(ArrowDataType::UInt8).boxed(),118_ if selection == S::UINT16 => Just(ArrowDataType::UInt16).boxed(),119_ if selection == S::UINT32 => Just(ArrowDataType::UInt32).boxed(),120_ if selection == S::UINT64 => Just(ArrowDataType::UInt64).boxed(),121_ if selection == S::FLOAT32 => Just(ArrowDataType::Float32).boxed(),122_ if selection == S::FLOAT64 => Just(ArrowDataType::Float64).boxed(),123_ if selection == S::STRVIEW => Just(ArrowDataType::Utf8View).boxed(),124_ if selection == S::BINVIEW => Just(ArrowDataType::BinaryView).boxed(),125_ if selection == S::BINARY => Just(ArrowDataType::LargeBinary).boxed(),126_ if selection == S::LIST => arrow_data_type_impl(options.clone(), nesting_level + 1)127.prop_map(|dtype| {128let field = Field::new("item".into(), dtype, true);129ArrowDataType::LargeList(Box::new(field))130})131.boxed(),132_ if selection == S::FIXED_SIZE_LIST => (133arrow_data_type_impl(options.clone(), nesting_level + 1),134options.array_width_range.clone(),135)136.prop_map(|(dtype, width)| {137let field = Field::new("item".into(), dtype, true);138ArrowDataType::FixedSizeList(Box::new(field), width)139})140.boxed(),141_ if selection == S::STRUCT => proptest::collection::vec(142arrow_data_type_impl(options.clone(), nesting_level + 1),143options.struct_num_fields_range.clone(),144)145.prop_map(|dtypes| {146let fields = dtypes147.into_iter()148.enumerate()149.map(|(i, dtype)| Field::new(format_pl_smallstr!("f{}", i + 1), dtype, true))150.collect();151ArrowDataType::Struct(fields)152})153.boxed(),154_ => unreachable!(),155}156})157}158159pub fn arrow_data_type(160options: ArrowDataTypeArbitraryOptions,161) -> impl Strategy<Value = ArrowDataType> {162arrow_data_type_impl(Rc::new(options), 0)163}164165pub fn array_with_dtype(166dtype: ArrowDataType,167size_range: impl Into<SizeRange>,168) -> impl Strategy<Value = Box<dyn Array>> {169let size_range = size_range.into();170match dtype {171ArrowDataType::Null => null_array(size_range).prop_map(NullArray::boxed).boxed(),172ArrowDataType::Boolean => boolean_array(size_range)173.prop_map(BooleanArray::boxed)174.boxed(),175ArrowDataType::Int8 => primitive_array::<i8>(size_range)176.prop_map(PrimitiveArray::boxed)177.boxed(),178ArrowDataType::Int16 => primitive_array::<i16>(size_range)179.prop_map(PrimitiveArray::boxed)180.boxed(),181ArrowDataType::Int32 => primitive_array::<i32>(size_range)182.prop_map(PrimitiveArray::boxed)183.boxed(),184ArrowDataType::Int64 => primitive_array::<i64>(size_range)185.prop_map(PrimitiveArray::boxed)186.boxed(),187ArrowDataType::Int128 => primitive_array::<i128>(size_range)188.prop_map(PrimitiveArray::boxed)189.boxed(),190ArrowDataType::UInt8 => primitive_array::<u8>(size_range)191.prop_map(PrimitiveArray::boxed)192.boxed(),193ArrowDataType::UInt16 => primitive_array::<u16>(size_range)194.prop_map(PrimitiveArray::boxed)195.boxed(),196ArrowDataType::UInt32 => primitive_array::<u32>(size_range)197.prop_map(PrimitiveArray::boxed)198.boxed(),199ArrowDataType::UInt64 => primitive_array::<u64>(size_range)200.prop_map(PrimitiveArray::boxed)201.boxed(),202ArrowDataType::Float32 => primitive_array::<f32>(size_range)203.prop_map(PrimitiveArray::boxed)204.boxed(),205ArrowDataType::Float64 => primitive_array::<f64>(size_range)206.prop_map(PrimitiveArray::boxed)207.boxed(),208ArrowDataType::LargeBinary => super::binary::proptest::binary_array(size_range)209.prop_map(BinaryArray::boxed)210.boxed(),211ArrowDataType::FixedSizeList(field, width) => {212super::fixed_size_list::proptest::fixed_size_list_array_with_dtype(213size_range, field, width,214)215.prop_map(FixedSizeListArray::boxed)216.boxed()217},218ArrowDataType::LargeList(field) => {219super::list::proptest::list_array_with_dtype(size_range, field)220.prop_map(ListArray::<i64>::boxed)221.boxed()222},223ArrowDataType::Struct(fields) => {224super::struct_::proptest::struct_array_with_fields(size_range, fields)225.prop_map(StructArray::boxed)226.boxed()227},228ArrowDataType::BinaryView => binview_array(size_range)229.prop_map(BinaryViewArray::boxed)230.boxed(),231ArrowDataType::Utf8View => utf8view_array(size_range)232.prop_map(Utf8ViewArray::boxed)233.boxed(),234ArrowDataType::Float16235| ArrowDataType::Timestamp(..)236| ArrowDataType::Date32237| ArrowDataType::Date64238| ArrowDataType::Time32(..)239| ArrowDataType::Time64(..)240| ArrowDataType::Duration(..)241| ArrowDataType::Interval(..)242| ArrowDataType::Binary243| ArrowDataType::FixedSizeBinary(_)244| ArrowDataType::Utf8245| ArrowDataType::LargeUtf8246| ArrowDataType::List(..)247| ArrowDataType::Map(_, _)248| ArrowDataType::Dictionary(..)249| ArrowDataType::Decimal(..)250| ArrowDataType::Decimal32(..)251| ArrowDataType::Decimal64(..)252| ArrowDataType::Decimal256(..)253| ArrowDataType::Extension(..)254| ArrowDataType::Unknown255| ArrowDataType::Union(..) => unimplemented!(),256}257}258259pub fn array_with_options(260size_range: impl Into<SizeRange>,261options: ArrayArbitraryOptions,262) -> impl Strategy<Value = Box<dyn Array>> {263let size_range = size_range.into();264arrow_data_type(options.dtype)265.prop_flat_map(move |dtype| array_with_dtype(dtype, size_range.clone()))266}267268pub fn array(size_range: impl Into<SizeRange>) -> impl Strategy<Value = Box<dyn Array>> {269let size_range = size_range.into();270arrow_data_type(Default::default())271.prop_flat_map(move |dtype| array_with_dtype(dtype, size_range.clone()))272}273274pub fn null_array(size_range: impl Into<SizeRange>) -> impl Strategy<Value = NullArray> {275let size_range = size_range.into();276let (min, max) = size_range.start_end_incl();277(min..=max).prop_map(|length| NullArray::new(ArrowDataType::Null, length))278}279280281