Path: blob/main/crates/polars-core/src/series/proptest.rs
8430 views
use std::ops::RangeInclusive;1use std::rc::Rc;2use std::sync::atomic::{AtomicUsize, Ordering};34use arrow::bitmap::bitmask::nth_set_bit_u32;5#[cfg(feature = "dtype-categorical")]6use polars_dtype::categorical::{Categories, FrozenCategories};7use proptest::prelude::*;89use crate::chunked_array::builder::AnonymousListBuilder;10#[cfg(feature = "dtype-categorical")]11use crate::chunked_array::builder::CategoricalChunkedBuilder;12use crate::prelude::{Int32Chunked, Int64Chunked, Int128Chunked, NamedFrom, Series, TimeUnit};13#[cfg(feature = "dtype-struct")]14use crate::series::StructChunked;15use crate::series::from::IntoSeries;16#[cfg(feature = "dtype-categorical")]17use crate::series::{Categorical8Type, DataType};1819// A global, thread-safe counter that will be used to ensure unique column names when the Series are created20// This is especially useful for when the Series strategies are combined to create a DataFrame strategy21static COUNTER: AtomicUsize = AtomicUsize::new(0);2223fn next_column_name() -> String {24format!("col_{}", COUNTER.fetch_add(1, Ordering::Relaxed))25}2627bitflags::bitflags! {28#[derive(Debug, Clone, Copy, PartialEq, Eq)]29pub struct SeriesArbitrarySelection: u32 {30const BOOLEAN = 1;31const UINT = 1 << 1;32const INT = 1 << 2;33const FLOAT = 1 << 3;34const STRING = 1 << 4;35const BINARY = 1 << 5;3637const TIME = 1 << 6;38const DATETIME = 1 << 7;39const DATE = 1 << 8;40const DURATION = 1 << 9;41const DECIMAL = 1 << 10;42const CATEGORICAL = 1 << 11;43const ENUM = 1 << 12;4445const LIST = 1 << 13;46const ARRAY = 1 << 14;47const STRUCT = 1 << 15;48}49}5051impl SeriesArbitrarySelection {52pub fn physical() -> Self {53Self::BOOLEAN | Self::UINT | Self::INT | Self::FLOAT | Self::STRING | Self::BINARY54}5556pub fn logical() -> Self {57Self::TIME58| Self::DATETIME59| Self::DATE60| Self::DURATION61| Self::DECIMAL62| Self::CATEGORICAL63| Self::ENUM64}6566pub fn nested() -> Self {67Self::LIST | Self::ARRAY | Self::STRUCT68}69}7071#[derive(Clone)]72pub struct SeriesArbitraryOptions {73pub allowed_dtypes: SeriesArbitrarySelection,74pub max_nesting_level: usize,75pub series_length_range: RangeInclusive<usize>,76pub categories_range: RangeInclusive<usize>,77pub struct_fields_range: RangeInclusive<usize>,78}7980impl Default for SeriesArbitraryOptions {81fn default() -> Self {82Self {83allowed_dtypes: SeriesArbitrarySelection::all(),84max_nesting_level: 3,85series_length_range: 0..=5,86categories_range: 0..=3,87struct_fields_range: 0..=3,88}89}90}9192pub fn series_strategy(93options: Rc<SeriesArbitraryOptions>,94nesting_level: usize,95) -> impl Strategy<Value = Series> {96use SeriesArbitrarySelection as S;9798let mut allowed_dtypes = options.allowed_dtypes;99100if options.max_nesting_level <= nesting_level {101allowed_dtypes &= !S::nested()102}103104let num_possible_types = allowed_dtypes.bits().count_ones();105assert!(num_possible_types > 0);106107(0..num_possible_types).prop_flat_map(move |i| {108let selection =109S::from_bits_retain(1 << nth_set_bit_u32(options.allowed_dtypes.bits(), i).unwrap());110111match selection {112_ if selection == S::BOOLEAN => {113series_boolean_strategy(options.series_length_range.clone()).boxed()114},115_ if selection == S::UINT => {116series_uint_strategy(options.series_length_range.clone()).boxed()117},118_ if selection == S::INT => {119series_int_strategy(options.series_length_range.clone()).boxed()120},121_ if selection == S::FLOAT => {122series_float_strategy(options.series_length_range.clone()).boxed()123},124_ if selection == S::STRING => {125series_string_strategy(options.series_length_range.clone()).boxed()126},127_ if selection == S::BINARY => {128series_binary_strategy(options.series_length_range.clone()).boxed()129},130#[cfg(feature = "dtype-time")]131_ if selection == S::TIME => {132series_time_strategy(options.series_length_range.clone()).boxed()133},134#[cfg(feature = "dtype-datetime")]135_ if selection == S::DATETIME => {136series_datetime_strategy(options.series_length_range.clone()).boxed()137},138#[cfg(feature = "dtype-date")]139_ if selection == S::DATE => {140series_date_strategy(options.series_length_range.clone()).boxed()141},142#[cfg(feature = "dtype-duration")]143_ if selection == S::DURATION => {144series_duration_strategy(options.series_length_range.clone()).boxed()145},146#[cfg(feature = "dtype-decimal")]147_ if selection == S::DECIMAL => {148series_decimal_strategy(options.series_length_range.clone()).boxed()149},150#[cfg(feature = "dtype-categorical")]151_ if selection == S::CATEGORICAL => series_categorical_strategy(152options.series_length_range.clone(),153options.categories_range.clone(),154)155.boxed(),156#[cfg(feature = "dtype-categorical")]157_ if selection == S::ENUM => series_enum_strategy(158options.series_length_range.clone(),159options.categories_range.clone(),160)161.boxed(),162_ if selection == S::LIST => series_list_strategy(163series_strategy(options.clone(), nesting_level + 1),164options.series_length_range.clone(),165)166.boxed(),167#[cfg(feature = "dtype-array")]168_ if selection == S::ARRAY => series_array_strategy(169series_strategy(options.clone(), nesting_level + 1),170options.series_length_range.clone(),171)172.boxed(),173#[cfg(feature = "dtype-struct")]174_ if selection == S::STRUCT => series_struct_strategy(175series_strategy(options.clone(), nesting_level + 1),176options.struct_fields_range.clone(),177)178.boxed(),179_ => unreachable!(),180}181})182}183184fn series_boolean_strategy(185series_length_range: RangeInclusive<usize>,186) -> impl Strategy<Value = Series> {187prop::collection::vec(any::<bool>(), series_length_range)188.prop_map(|bools| Series::new(next_column_name().into(), bools))189}190191fn series_uint_strategy(192series_length_range: RangeInclusive<usize>,193) -> impl Strategy<Value = Series> {194prop_oneof![195prop::collection::vec(any::<u8>(), series_length_range.clone())196.prop_map(|uints| Series::new(next_column_name().into(), uints)),197prop::collection::vec(any::<u16>(), series_length_range.clone())198.prop_map(|uints| Series::new(next_column_name().into(), uints)),199prop::collection::vec(any::<u32>(), series_length_range.clone())200.prop_map(|uints| Series::new(next_column_name().into(), uints)),201prop::collection::vec(any::<u64>(), series_length_range.clone())202.prop_map(|uints| Series::new(next_column_name().into(), uints)),203prop::collection::vec(any::<u128>(), series_length_range)204.prop_map(|uints| Series::new(next_column_name().into(), uints)),205]206}207208fn series_int_strategy(209series_length_range: RangeInclusive<usize>,210) -> impl Strategy<Value = Series> {211prop_oneof![212prop::collection::vec(any::<i8>(), series_length_range.clone())213.prop_map(|ints| Series::new(next_column_name().into(), ints)),214prop::collection::vec(any::<i16>(), series_length_range.clone())215.prop_map(|ints| Series::new(next_column_name().into(), ints)),216prop::collection::vec(any::<i32>(), series_length_range.clone())217.prop_map(|ints| Series::new(next_column_name().into(), ints)),218prop::collection::vec(any::<i64>(), series_length_range.clone())219.prop_map(|ints| Series::new(next_column_name().into(), ints)),220prop::collection::vec(any::<i128>(), series_length_range)221.prop_map(|ints| Series::new(next_column_name().into(), ints)),222]223}224225fn series_float_strategy(226series_length_range: RangeInclusive<usize>,227) -> impl Strategy<Value = Series> {228prop_oneof![229prop::collection::vec(any::<f32>(), series_length_range.clone())230.prop_map(|floats| Series::new(next_column_name().into(), floats)),231prop::collection::vec(any::<f64>(), series_length_range)232.prop_map(|floats| Series::new(next_column_name().into(), floats)),233]234}235236fn series_string_strategy(237series_length_range: RangeInclusive<usize>,238) -> impl Strategy<Value = Series> {239prop::collection::vec(any::<String>(), series_length_range)240.prop_map(|strings| Series::new(next_column_name().into(), strings))241}242243fn series_binary_strategy(244series_length_range: RangeInclusive<usize>,245) -> impl Strategy<Value = Series> {246prop::collection::vec(any::<u8>(), series_length_range)247.prop_map(|binaries| Series::new(next_column_name().into(), binaries))248}249250#[cfg(feature = "dtype-time")]251fn series_time_strategy(252series_length_range: RangeInclusive<usize>,253) -> impl Strategy<Value = Series> {254prop::collection::vec(2550i64..86_400_000_000_000i64, // Time range: 0 to just under 24 hours in nanoseconds256series_length_range,257)258.prop_map(|times| {259Int64Chunked::new(next_column_name().into(), ×)260.into_time()261.into_series()262})263}264265#[cfg(feature = "dtype-datetime")]266fn series_datetime_strategy(267series_length_range: RangeInclusive<usize>,268) -> impl Strategy<Value = Series> {269prop::collection::vec(2700i64..i64::MAX, // Datetime range: 0 (1970-01-01) to i64::MAX in milliseconds since UNIX epoch271series_length_range,272)273.prop_map(|datetimes| {274Int64Chunked::new(next_column_name().into(), &datetimes)275.into_datetime(TimeUnit::Milliseconds, None)276.into_series()277})278}279280#[cfg(feature = "dtype-date")]281fn series_date_strategy(282series_length_range: RangeInclusive<usize>,283) -> impl Strategy<Value = Series> {284prop::collection::vec(2850i32..50_000i32, // Date range: 0 (1970-01-01) to ~50,000 days (~137 years, roughly 1970-2107)286series_length_range,287)288.prop_map(|dates| {289Int32Chunked::new(next_column_name().into(), &dates)290.into_date()291.into_series()292})293}294295#[cfg(feature = "dtype-duration")]296fn series_duration_strategy(297series_length_range: RangeInclusive<usize>,298) -> impl Strategy<Value = Series> {299prop::collection::vec(300i64::MIN..i64::MAX, // Duration range: full i64 range in milliseconds (can be negative for time differences)301series_length_range,302)303.prop_map(|durations| {304Int64Chunked::new(next_column_name().into(), &durations)305.into_duration(TimeUnit::Milliseconds)306.into_series()307})308}309310#[cfg(feature = "dtype-decimal")]311fn series_decimal_strategy(312series_length_range: RangeInclusive<usize>,313) -> impl Strategy<Value = Series> {314prop::collection::vec(i128::MIN..i128::MAX, series_length_range).prop_map(|decimals| {315Int128Chunked::new(next_column_name().into(), &decimals)316.into_decimal_unchecked(38, 9) // precision = 38 (max for i128), scale = 9 (9 decimal places)317.into_series()318})319}320321#[cfg(feature = "dtype-categorical")]322fn series_categorical_strategy(323series_length_range: RangeInclusive<usize>,324categories_range: RangeInclusive<usize>,325) -> impl Strategy<Value = Series> {326categories_range327.prop_flat_map(move |n_categories| {328let possible_categories: Vec<String> =329(0..n_categories).map(|i| format!("category{i}")).collect();330331prop::collection::vec(332prop::sample::select(possible_categories),333series_length_range.clone(),334)335})336.prop_map(|categories| {337// Using Categorical8Type (u8 backing) which supports up to 256 unique categories338let mapping = Categories::global().mapping();339let mut builder = CategoricalChunkedBuilder::<Categorical8Type>::new(340next_column_name().into(),341DataType::Categorical(Categories::global(), mapping),342);343344for category in categories {345builder.append_str(&category).unwrap();346}347348builder.finish().into_series()349})350}351352#[cfg(feature = "dtype-categorical")]353fn series_enum_strategy(354series_length_range: RangeInclusive<usize>,355categories_range: RangeInclusive<usize>,356) -> impl Strategy<Value = Series> {357categories_range358.prop_flat_map(move |n_categories| {359let possible_categories: Vec<String> =360(0..n_categories).map(|i| format!("category{i}")).collect();361362(363Just(possible_categories.clone()),364prop::collection::vec(365prop::sample::select(possible_categories),366series_length_range.clone(),367),368)369})370.prop_map(|(possible_categories, sampled_categories)| {371let frozen_categories =372FrozenCategories::new(possible_categories.iter().map(|s| s.as_str())).unwrap();373let mapping = frozen_categories.mapping().clone();374375// Using Categorical8Type (u8 backing) which supports up to 256 unique categories376let mut builder = CategoricalChunkedBuilder::<Categorical8Type>::new(377next_column_name().into(),378DataType::Enum(frozen_categories, mapping),379);380381for category in sampled_categories {382builder.append_str(&category).unwrap();383}384385builder.finish().into_series()386})387}388389fn series_list_strategy(390inner: impl Strategy<Value = Series>,391series_length_range: RangeInclusive<usize>,392) -> impl Strategy<Value = Series> {393inner.prop_flat_map(move |sample_series| {394series_length_range.clone().prop_map(move |num_lists| {395let mut builder = AnonymousListBuilder::new(396next_column_name().into(),397num_lists,398Some(sample_series.dtype().clone()),399);400401for _ in 0..num_lists {402builder.append_series(&sample_series).unwrap();403}404405builder.finish().into_series()406})407})408}409410#[cfg(feature = "dtype-array")]411fn series_array_strategy(412inner: impl Strategy<Value = Series>,413series_length_range: RangeInclusive<usize>,414) -> impl Strategy<Value = Series> {415inner.prop_flat_map(move |sample_series| {416series_length_range.clone().prop_map(move |num_arrays| {417let width = sample_series.len();418419let mut builder = AnonymousListBuilder::new(420next_column_name().into(),421num_arrays,422Some(sample_series.dtype().clone()),423);424425for _ in 0..num_arrays {426builder.append_series(&sample_series).unwrap();427}428429let list_series = builder.finish().into_series();430431list_series432.cast(&DataType::Array(433Box::new(sample_series.dtype().clone()),434width,435))436.unwrap()437})438})439}440441#[cfg(feature = "dtype-struct")]442fn series_struct_strategy(443inner: impl Strategy<Value = Series>,444struct_fields_range: RangeInclusive<usize>,445) -> impl Strategy<Value = Series> {446inner.prop_flat_map(move |sample_series| {447struct_fields_range.clone().prop_map(move |num_fields| {448let length = sample_series.len();449450let fields: Vec<Series> = (0..num_fields)451.map(|i| {452let mut field = sample_series.clone();453field.rename(format!("field_{}", i).into());454field455})456.collect();457458StructChunked::from_series(next_column_name().into(), length, fields.iter())459.unwrap()460.into_series()461})462})463}464465466