Path: blob/main/crates/polars-ops/src/series/ops/fused.rs
6939 views
use arrow::array::PrimitiveArray;1use arrow::compute::utils::combine_validities_and3;2use polars_core::prelude::*;3use polars_core::utils::align_chunks_ternary;4use polars_core::with_match_physical_numeric_polars_type;56// a + (b * c)7fn fma_arr<T: NumericNative>(8a: &PrimitiveArray<T>,9b: &PrimitiveArray<T>,10c: &PrimitiveArray<T>,11) -> PrimitiveArray<T> {12assert_eq!(a.len(), b.len());13let validity = combine_validities_and3(a.validity(), b.validity(), c.validity());14let a = a.values().as_slice();15let b = b.values().as_slice();16let c = c.values().as_slice();1718assert_eq!(a.len(), b.len());19assert_eq!(b.len(), c.len());20let out = a21.iter()22.zip(b.iter())23.zip(c.iter())24.map(|((a, b), c)| *a + (*b * *c))25.collect::<Vec<_>>();26PrimitiveArray::from_data_default(out.into(), validity)27}2829fn fma_ca<T: PolarsNumericType>(30a: &ChunkedArray<T>,31b: &ChunkedArray<T>,32c: &ChunkedArray<T>,33) -> ChunkedArray<T> {34let (a, b, c) = align_chunks_ternary(a, b, c);35let chunks = a36.downcast_iter()37.zip(b.downcast_iter())38.zip(c.downcast_iter())39.map(|((a, b), c)| fma_arr(a, b, c));40ChunkedArray::from_chunk_iter(a.name().clone(), chunks)41}4243pub fn fma_columns(a: &Column, b: &Column, c: &Column) -> Column {44if a.len() == b.len() && a.len() == c.len() {45with_match_physical_numeric_polars_type!(a.dtype(), |$T| {46let a: &ChunkedArray<$T> = a.as_materialized_series().as_ref().as_ref().as_ref();47let b: &ChunkedArray<$T> = b.as_materialized_series().as_ref().as_ref().as_ref();48let c: &ChunkedArray<$T> = c.as_materialized_series().as_ref().as_ref().as_ref();4950fma_ca(a, b, c).into_column()51})52} else {53(a.as_materialized_series()54+ &(b.as_materialized_series() * c.as_materialized_series()).unwrap())55.unwrap()56.into()57}58}5960// a - (b * c)61fn fsm_arr<T: NumericNative>(62a: &PrimitiveArray<T>,63b: &PrimitiveArray<T>,64c: &PrimitiveArray<T>,65) -> PrimitiveArray<T> {66assert_eq!(a.len(), b.len());67let validity = combine_validities_and3(a.validity(), b.validity(), c.validity());68let a = a.values().as_slice();69let b = b.values().as_slice();70let c = c.values().as_slice();7172assert_eq!(a.len(), b.len());73assert_eq!(b.len(), c.len());74let out = a75.iter()76.zip(b.iter())77.zip(c.iter())78.map(|((a, b), c)| *a - (*b * *c))79.collect::<Vec<_>>();80PrimitiveArray::from_data_default(out.into(), validity)81}8283fn fsm_ca<T: PolarsNumericType>(84a: &ChunkedArray<T>,85b: &ChunkedArray<T>,86c: &ChunkedArray<T>,87) -> ChunkedArray<T> {88let (a, b, c) = align_chunks_ternary(a, b, c);89let chunks = a90.downcast_iter()91.zip(b.downcast_iter())92.zip(c.downcast_iter())93.map(|((a, b), c)| fsm_arr(a, b, c));94ChunkedArray::from_chunk_iter(a.name().clone(), chunks)95}9697pub fn fsm_columns(a: &Column, b: &Column, c: &Column) -> Column {98if a.len() == b.len() && a.len() == c.len() {99with_match_physical_numeric_polars_type!(a.dtype(), |$T| {100let a: &ChunkedArray<$T> = a.as_materialized_series().as_ref().as_ref().as_ref();101let b: &ChunkedArray<$T> = b.as_materialized_series().as_ref().as_ref().as_ref();102let c: &ChunkedArray<$T> = c.as_materialized_series().as_ref().as_ref().as_ref();103104fsm_ca(a, b, c).into_column()105})106} else {107(a.as_materialized_series()108- &(b.as_materialized_series() * c.as_materialized_series()).unwrap())109.unwrap()110.into()111}112}113114fn fms_arr<T: NumericNative>(115a: &PrimitiveArray<T>,116b: &PrimitiveArray<T>,117c: &PrimitiveArray<T>,118) -> PrimitiveArray<T> {119assert_eq!(a.len(), b.len());120let validity = combine_validities_and3(a.validity(), b.validity(), c.validity());121let a = a.values().as_slice();122let b = b.values().as_slice();123let c = c.values().as_slice();124125assert_eq!(a.len(), b.len());126assert_eq!(b.len(), c.len());127let out = a128.iter()129.zip(b.iter())130.zip(c.iter())131.map(|((a, b), c)| (*a * *b) - *c)132.collect::<Vec<_>>();133PrimitiveArray::from_data_default(out.into(), validity)134}135136fn fms_ca<T: PolarsNumericType>(137a: &ChunkedArray<T>,138b: &ChunkedArray<T>,139c: &ChunkedArray<T>,140) -> ChunkedArray<T> {141let (a, b, c) = align_chunks_ternary(a, b, c);142let chunks = a143.downcast_iter()144.zip(b.downcast_iter())145.zip(c.downcast_iter())146.map(|((a, b), c)| fms_arr(a, b, c));147ChunkedArray::from_chunk_iter(a.name().clone(), chunks)148}149150pub fn fms_columns(a: &Column, b: &Column, c: &Column) -> Column {151if a.len() == b.len() && a.len() == c.len() {152with_match_physical_numeric_polars_type!(a.dtype(), |$T| {153let a: &ChunkedArray<$T> = a.as_materialized_series().as_ref().as_ref().as_ref();154let b: &ChunkedArray<$T> = b.as_materialized_series().as_ref().as_ref().as_ref();155let c: &ChunkedArray<$T> = c.as_materialized_series().as_ref().as_ref().as_ref();156157fms_ca(a, b, c).into_column()158})159} else {160(&(a.as_materialized_series() * b.as_materialized_series()).unwrap()161- c.as_materialized_series())162.unwrap()163.into()164}165}166167168