Path: blob/main/crates/polars-core/src/series/proptest.rs
7884 views
use std::ops::RangeInclusive;1use std::rc::Rc;2use std::sync::Arc;3use std::sync::atomic::{AtomicUsize, Ordering};45use arrow::bitmap::bitmask::nth_set_bit_u32;6#[cfg(feature = "dtype-categorical")]7use polars_dtype::categorical::{CategoricalMapping, Categories, FrozenCategories};8use proptest::prelude::*;910use crate::chunked_array::builder::AnonymousListBuilder;11#[cfg(feature = "dtype-categorical")]12use crate::chunked_array::builder::CategoricalChunkedBuilder;13use crate::prelude::{Int32Chunked, Int64Chunked, Int128Chunked, NamedFrom, Series, TimeUnit};14#[cfg(feature = "dtype-struct")]15use crate::series::StructChunked;16use crate::series::from::IntoSeries;17#[cfg(feature = "dtype-categorical")]18use crate::series::{Categorical8Type, DataType};1920// A global, thread-safe counter that will be used to ensure unique column names when the Series are created21// This is especially useful for when the Series strategies are combined to create a DataFrame strategy22static COUNTER: AtomicUsize = AtomicUsize::new(0);2324fn next_column_name() -> String {25format!("col_{}", COUNTER.fetch_add(1, Ordering::Relaxed))26}2728bitflags::bitflags! {29#[derive(Debug, Clone, Copy, PartialEq, Eq)]30pub struct SeriesArbitrarySelection: u32 {31const BOOLEAN = 1;32const UINT = 1 << 1;33const INT = 1 << 2;34const FLOAT = 1 << 3;35const STRING = 1 << 4;36const BINARY = 1 << 5;3738const TIME = 1 << 6;39const DATETIME = 1 << 7;40const DATE = 1 << 8;41const DURATION = 1 << 9;42const DECIMAL = 1 << 10;43const CATEGORICAL = 1 << 11;44const ENUM = 1 << 12;4546const LIST = 1 << 13;47const ARRAY = 1 << 14;48const STRUCT = 1 << 15;49}50}5152impl SeriesArbitrarySelection {53pub fn physical() -> Self {54Self::BOOLEAN | Self::UINT | Self::INT | Self::FLOAT | Self::STRING | Self::BINARY55}5657pub fn logical() -> Self {58Self::TIME59| Self::DATETIME60| Self::DATE61| Self::DURATION62| Self::DECIMAL63| Self::CATEGORICAL64| Self::ENUM65}6667pub fn nested() -> Self {68Self::LIST | Self::ARRAY | Self::STRUCT69}70}7172#[derive(Clone)]73pub struct SeriesArbitraryOptions {74pub allowed_dtypes: SeriesArbitrarySelection,75pub max_nesting_level: usize,76pub series_length_range: RangeInclusive<usize>,77pub categories_range: RangeInclusive<usize>,78pub struct_fields_range: RangeInclusive<usize>,79}8081impl Default for SeriesArbitraryOptions {82fn default() -> Self {83Self {84allowed_dtypes: SeriesArbitrarySelection::all(),85max_nesting_level: 3,86series_length_range: 0..=5,87categories_range: 0..=3,88struct_fields_range: 0..=3,89}90}91}9293pub fn series_strategy(94options: Rc<SeriesArbitraryOptions>,95nesting_level: usize,96) -> impl Strategy<Value = Series> {97use SeriesArbitrarySelection as S;9899let mut allowed_dtypes = options.allowed_dtypes;100101if options.max_nesting_level <= nesting_level {102allowed_dtypes &= !S::nested()103}104105let num_possible_types = allowed_dtypes.bits().count_ones();106assert!(num_possible_types > 0);107108(0..num_possible_types).prop_flat_map(move |i| {109let selection =110S::from_bits_retain(1 << nth_set_bit_u32(options.allowed_dtypes.bits(), i).unwrap());111112match selection {113_ if selection == S::BOOLEAN => {114series_boolean_strategy(options.series_length_range.clone()).boxed()115},116_ if selection == S::UINT => {117series_uint_strategy(options.series_length_range.clone()).boxed()118},119_ if selection == S::INT => {120series_int_strategy(options.series_length_range.clone()).boxed()121},122_ if selection == S::FLOAT => {123series_float_strategy(options.series_length_range.clone()).boxed()124},125_ if selection == S::STRING => {126series_string_strategy(options.series_length_range.clone()).boxed()127},128_ if selection == S::BINARY => {129series_binary_strategy(options.series_length_range.clone()).boxed()130},131#[cfg(feature = "dtype-time")]132_ if selection == S::TIME => {133series_time_strategy(options.series_length_range.clone()).boxed()134},135#[cfg(feature = "dtype-datetime")]136_ if selection == S::DATETIME => {137series_datetime_strategy(options.series_length_range.clone()).boxed()138},139#[cfg(feature = "dtype-date")]140_ if selection == S::DATE => {141series_date_strategy(options.series_length_range.clone()).boxed()142},143#[cfg(feature = "dtype-duration")]144_ if selection == S::DURATION => {145series_duration_strategy(options.series_length_range.clone()).boxed()146},147#[cfg(feature = "dtype-decimal")]148_ if selection == S::DECIMAL => {149series_decimal_strategy(options.series_length_range.clone()).boxed()150},151#[cfg(feature = "dtype-categorical")]152_ if selection == S::CATEGORICAL => series_categorical_strategy(153options.series_length_range.clone(),154options.categories_range.clone(),155)156.boxed(),157#[cfg(feature = "dtype-categorical")]158_ if selection == S::ENUM => series_enum_strategy(159options.series_length_range.clone(),160options.categories_range.clone(),161)162.boxed(),163_ if selection == S::LIST => series_list_strategy(164series_strategy(options.clone(), nesting_level + 1),165options.series_length_range.clone(),166)167.boxed(),168#[cfg(feature = "dtype-array")]169_ if selection == S::ARRAY => series_array_strategy(170series_strategy(options.clone(), nesting_level + 1),171options.series_length_range.clone(),172)173.boxed(),174#[cfg(feature = "dtype-struct")]175_ if selection == S::STRUCT => series_struct_strategy(176series_strategy(options.clone(), nesting_level + 1),177options.struct_fields_range.clone(),178)179.boxed(),180_ => unreachable!(),181}182})183}184185fn series_boolean_strategy(186series_length_range: RangeInclusive<usize>,187) -> impl Strategy<Value = Series> {188prop::collection::vec(any::<bool>(), series_length_range)189.prop_map(|bools| Series::new(next_column_name().into(), bools))190}191192fn series_uint_strategy(193series_length_range: RangeInclusive<usize>,194) -> impl Strategy<Value = Series> {195prop_oneof![196prop::collection::vec(any::<u8>(), series_length_range.clone())197.prop_map(|uints| Series::new(next_column_name().into(), uints)),198prop::collection::vec(any::<u16>(), series_length_range.clone())199.prop_map(|uints| Series::new(next_column_name().into(), uints)),200prop::collection::vec(any::<u32>(), series_length_range.clone())201.prop_map(|uints| Series::new(next_column_name().into(), uints)),202prop::collection::vec(any::<u64>(), series_length_range.clone())203.prop_map(|uints| Series::new(next_column_name().into(), uints)),204prop::collection::vec(any::<u128>(), series_length_range)205.prop_map(|uints| Series::new(next_column_name().into(), uints)),206]207}208209fn series_int_strategy(210series_length_range: RangeInclusive<usize>,211) -> impl Strategy<Value = Series> {212prop_oneof![213prop::collection::vec(any::<i8>(), series_length_range.clone())214.prop_map(|ints| Series::new(next_column_name().into(), ints)),215prop::collection::vec(any::<i16>(), series_length_range.clone())216.prop_map(|ints| Series::new(next_column_name().into(), ints)),217prop::collection::vec(any::<i32>(), series_length_range.clone())218.prop_map(|ints| Series::new(next_column_name().into(), ints)),219prop::collection::vec(any::<i64>(), series_length_range.clone())220.prop_map(|ints| Series::new(next_column_name().into(), ints)),221prop::collection::vec(any::<i128>(), series_length_range)222.prop_map(|ints| Series::new(next_column_name().into(), ints)),223]224}225226fn series_float_strategy(227series_length_range: RangeInclusive<usize>,228) -> impl Strategy<Value = Series> {229prop_oneof![230prop::collection::vec(any::<f32>(), series_length_range.clone())231.prop_map(|floats| Series::new(next_column_name().into(), floats)),232prop::collection::vec(any::<f64>(), series_length_range)233.prop_map(|floats| Series::new(next_column_name().into(), floats)),234]235}236237fn series_string_strategy(238series_length_range: RangeInclusive<usize>,239) -> impl Strategy<Value = Series> {240prop::collection::vec(any::<String>(), series_length_range)241.prop_map(|strings| Series::new(next_column_name().into(), strings))242}243244fn series_binary_strategy(245series_length_range: RangeInclusive<usize>,246) -> impl Strategy<Value = Series> {247prop::collection::vec(any::<u8>(), series_length_range)248.prop_map(|binaries| Series::new(next_column_name().into(), binaries))249}250251#[cfg(feature = "dtype-time")]252fn series_time_strategy(253series_length_range: RangeInclusive<usize>,254) -> impl Strategy<Value = Series> {255prop::collection::vec(2560i64..86_400_000_000_000i64, // Time range: 0 to just under 24 hours in nanoseconds257series_length_range,258)259.prop_map(|times| {260Int64Chunked::new(next_column_name().into(), ×)261.into_time()262.into_series()263})264}265266#[cfg(feature = "dtype-datetime")]267fn series_datetime_strategy(268series_length_range: RangeInclusive<usize>,269) -> impl Strategy<Value = Series> {270prop::collection::vec(2710i64..i64::MAX, // Datetime range: 0 (1970-01-01) to i64::MAX in milliseconds since UNIX epoch272series_length_range,273)274.prop_map(|datetimes| {275Int64Chunked::new(next_column_name().into(), &datetimes)276.into_datetime(TimeUnit::Milliseconds, None)277.into_series()278})279}280281#[cfg(feature = "dtype-date")]282fn series_date_strategy(283series_length_range: RangeInclusive<usize>,284) -> impl Strategy<Value = Series> {285prop::collection::vec(2860i32..50_000i32, // Date range: 0 (1970-01-01) to ~50,000 days (~137 years, roughly 1970-2107)287series_length_range,288)289.prop_map(|dates| {290Int32Chunked::new(next_column_name().into(), &dates)291.into_date()292.into_series()293})294}295296#[cfg(feature = "dtype-duration")]297fn series_duration_strategy(298series_length_range: RangeInclusive<usize>,299) -> impl Strategy<Value = Series> {300prop::collection::vec(301i64::MIN..i64::MAX, // Duration range: full i64 range in milliseconds (can be negative for time differences)302series_length_range,303)304.prop_map(|durations| {305Int64Chunked::new(next_column_name().into(), &durations)306.into_duration(TimeUnit::Milliseconds)307.into_series()308})309}310311#[cfg(feature = "dtype-decimal")]312fn series_decimal_strategy(313series_length_range: RangeInclusive<usize>,314) -> impl Strategy<Value = Series> {315prop::collection::vec(i128::MIN..i128::MAX, series_length_range).prop_map(|decimals| {316Int128Chunked::new(next_column_name().into(), &decimals)317.into_decimal_unchecked(38, 9) // precision = 38 (max for i128), scale = 9 (9 decimal places)318.into_series()319})320}321322#[cfg(feature = "dtype-categorical")]323fn series_categorical_strategy(324series_length_range: RangeInclusive<usize>,325categories_range: RangeInclusive<usize>,326) -> impl Strategy<Value = Series> {327categories_range328.prop_flat_map(move |n_categories| {329let possible_categories: Vec<String> =330(0..n_categories).map(|i| format!("category{i}")).collect();331332prop::collection::vec(333prop::sample::select(possible_categories),334series_length_range.clone(),335)336})337.prop_map(|categories| {338// Using Categorical8Type (u8 backing) which supports up to 256 unique categories339let mut builder = CategoricalChunkedBuilder::<Categorical8Type>::new(340next_column_name().into(),341DataType::Categorical(Categories::global(), Arc::new(CategoricalMapping::new(256))),342);343344for category in categories {345builder.append_str(&category).unwrap();346}347348builder.finish().into_series()349})350}351352#[cfg(feature = "dtype-categorical")]353fn series_enum_strategy(354series_length_range: RangeInclusive<usize>,355categories_range: RangeInclusive<usize>,356) -> impl Strategy<Value = Series> {357categories_range358.prop_flat_map(move |n_categories| {359let possible_categories: Vec<String> =360(0..n_categories).map(|i| format!("category{i}")).collect();361362(363Just(possible_categories.clone()),364prop::collection::vec(365prop::sample::select(possible_categories),366series_length_range.clone(),367),368)369})370.prop_map(|(possible_categories, sampled_categories)| {371let frozen_categories =372FrozenCategories::new(possible_categories.iter().map(|s| s.as_str())).unwrap();373let mapping = frozen_categories.mapping().clone();374375// Using Categorical8Type (u8 backing) which supports up to 256 unique categories376let mut builder = CategoricalChunkedBuilder::<Categorical8Type>::new(377next_column_name().into(),378DataType::Enum(frozen_categories, mapping),379);380381for category in sampled_categories {382builder.append_str(&category).unwrap();383}384385builder.finish().into_series()386})387}388389fn series_list_strategy(390inner: impl Strategy<Value = Series>,391series_length_range: RangeInclusive<usize>,392) -> impl Strategy<Value = Series> {393inner.prop_flat_map(move |sample_series| {394series_length_range.clone().prop_map(move |num_lists| {395let mut builder = AnonymousListBuilder::new(396next_column_name().into(),397num_lists,398Some(sample_series.dtype().clone()),399);400401for _ in 0..num_lists {402builder.append_series(&sample_series).unwrap();403}404405builder.finish().into_series()406})407})408}409410#[cfg(feature = "dtype-array")]411fn series_array_strategy(412inner: impl Strategy<Value = Series>,413series_length_range: RangeInclusive<usize>,414) -> impl Strategy<Value = Series> {415inner.prop_flat_map(move |sample_series| {416series_length_range.clone().prop_map(move |num_arrays| {417let width = sample_series.len();418419let mut builder = AnonymousListBuilder::new(420next_column_name().into(),421num_arrays,422Some(sample_series.dtype().clone()),423);424425for _ in 0..num_arrays {426builder.append_series(&sample_series).unwrap();427}428429let list_series = builder.finish().into_series();430431list_series432.cast(&DataType::Array(433Box::new(sample_series.dtype().clone()),434width,435))436.unwrap()437})438})439}440441#[cfg(feature = "dtype-struct")]442fn series_struct_strategy(443inner: impl Strategy<Value = Series>,444struct_fields_range: RangeInclusive<usize>,445) -> impl Strategy<Value = Series> {446inner.prop_flat_map(move |sample_series| {447struct_fields_range.clone().prop_map(move |num_fields| {448let length = sample_series.len();449450let fields: Vec<Series> = (0..num_fields)451.map(|i| {452let mut field = sample_series.clone();453field.rename(format!("field_{}", i).into());454field455})456.collect();457458StructChunked::from_series(next_column_name().into(), length, fields.iter())459.unwrap()460.into_series()461})462})463}464465466