Path: blob/main/crates/polars-expr/src/reduce/var_std.rs
8420 views
use std::marker::PhantomData;12use num_traits::AsPrimitive;3use polars_compute::moment::VarState;4use polars_core::with_match_physical_numeric_polars_type;56use super::*;78pub fn new_var_std_reduction(9dtype: DataType,10is_std: bool,11ddof: u8,12) -> PolarsResult<Box<dyn GroupedReduction>> {13// TODO: Move the error checks up and make this function infallible14use DataType::*;15use VecGroupedReduction as VGR;16let op_name = if is_std { "std" } else { "var" };17Ok(match dtype {18Boolean => Box::new(VGR::new(dtype, BoolVarStdReducer { is_std, ddof })),19_ if dtype.is_primitive_numeric() => {20with_match_physical_numeric_polars_type!(dtype.to_physical(), |$T| {21Box::new(VGR::new(dtype, VarStdReducer::<$T> {22is_std,23ddof,24needs_cast: false,25_phantom: PhantomData,26}))27})28},29#[cfg(feature = "dtype-decimal")]30Decimal(_, _) => Box::new(VGR::new(31dtype,32VarStdReducer::<Float64Type> {33is_std,34ddof,35needs_cast: true,36_phantom: PhantomData,37},38)),39Duration(..) => todo!(),40Null => Box::new(super::NullGroupedReduction::new(Scalar::null(41DataType::Null,42))),43_ => {44polars_bail!(InvalidOperation: "`{op_name}` operation not supported for dtype `{dtype}`")45},46})47}4849struct VarStdReducer<T> {50is_std: bool,51ddof: u8,52needs_cast: bool,53_phantom: PhantomData<T>,54}5556impl<T> Clone for VarStdReducer<T> {57fn clone(&self) -> Self {58Self {59is_std: self.is_std,60ddof: self.ddof,61needs_cast: self.needs_cast,62_phantom: PhantomData,63}64}65}6667impl<T: PolarsNumericType> Reducer for VarStdReducer<T> {68type Dtype = T;69type Value = VarState;7071fn init(&self) -> Self::Value {72VarState::default()73}7475fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> {76if self.needs_cast {77Cow::Owned(s.cast(&DataType::Float64).unwrap())78} else {79Cow::Borrowed(s)80}81}8283fn combine(&self, a: &mut Self::Value, b: &Self::Value) {84a.combine(b)85}8687#[inline(always)]88fn reduce_one(&self, a: &mut Self::Value, b: Option<T::Native>, _seq_id: u64) {89if let Some(x) = b {90a.insert_one(x.as_());91}92}9394fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray<Self::Dtype>, _seq_id: u64) {95for arr in ca.downcast_iter() {96v.combine(&polars_compute::moment::var(arr))97}98}99100fn finish(101&self,102v: Vec<Self::Value>,103m: Option<Bitmap>,104dtype: &DataType,105) -> PolarsResult<Series> {106assert!(m.is_none());107match dtype {108#[cfg(feature = "dtype-f16")]109DataType::Float16 => {110let ca: Float16Chunked = v111.into_iter()112.map(|s| {113let var = s.finalize(self.ddof);114let out = if self.is_std { var.map(f64::sqrt) } else { var };115out.map(|v| v.as_())116})117.collect_ca(PlSmallStr::EMPTY);118Ok(ca.into_series())119},120DataType::Float32 => {121let ca: Float32Chunked = v122.into_iter()123.map(|s| {124let var = s.finalize(self.ddof);125let out = if self.is_std { var.map(f64::sqrt) } else { var };126out.map(|v| v as f32)127})128.collect_ca(PlSmallStr::EMPTY);129Ok(ca.into_series())130},131_ => {132let ca: Float64Chunked = v133.into_iter()134.map(|s| {135let var = s.finalize(self.ddof);136if self.is_std { var.map(f64::sqrt) } else { var }137})138.collect_ca(PlSmallStr::EMPTY);139Ok(ca.into_series())140},141}142}143}144145#[derive(Clone)]146struct BoolVarStdReducer {147is_std: bool,148ddof: u8,149}150151impl Reducer for BoolVarStdReducer {152type Dtype = BooleanType;153type Value = (usize, usize);154155fn init(&self) -> Self::Value {156(0, 0)157}158159fn combine(&self, a: &mut Self::Value, b: &Self::Value) {160a.0 += b.0;161a.1 += b.1;162}163164#[inline(always)]165fn reduce_one(&self, a: &mut Self::Value, b: Option<bool>, _seq_id: u64) {166a.0 += b.unwrap_or(false) as usize;167a.1 += b.is_some() as usize;168}169170fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray<Self::Dtype>, _seq_id: u64) {171v.0 += ca.sum().unwrap_or(0) as usize;172v.1 += ca.len() - ca.null_count();173}174175fn finish(176&self,177v: Vec<Self::Value>,178m: Option<Bitmap>,179_dtype: &DataType,180) -> PolarsResult<Series> {181assert!(m.is_none());182let ca: Float64Chunked = v183.into_iter()184.map(|v| {185if v.1 <= self.ddof as usize {186return None;187}188189let sum = v.0 as f64; // Both the sum and sum-of-squares, letting us simplify.190let n = v.1;191let var = sum * (1.0 - sum / n as f64) / ((n - self.ddof as usize) as f64);192if self.is_std {193Some(var.sqrt())194} else {195Some(var)196}197})198.collect_ca(PlSmallStr::EMPTY);199Ok(ca.into_series())200}201}202203204