Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/series/ops/horizontal.rs
6939 views
1
use std::borrow::Cow;
2
3
use polars_core::chunked_array::cast::CastOptions;
4
use polars_core::prelude::*;
5
use polars_core::series::arithmetic::coerce_lhs_rhs;
6
use polars_core::utils::dtypes_to_supertype;
7
use polars_core::{POOL, with_match_physical_numeric_polars_type};
8
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
9
10
fn validate_column_lengths(cs: &[Column]) -> PolarsResult<()> {
11
let mut length = 1;
12
for c in cs {
13
let len = c.len();
14
if len != 1 && len != length {
15
if length == 1 {
16
length = len;
17
} else {
18
polars_bail!(ShapeMismatch: "cannot evaluate two Series of different lengths ({len} and {length})");
19
}
20
}
21
}
22
Ok(())
23
}
24
25
pub trait MinMaxHorizontal {
26
/// Aggregate the column horizontally to their min values.
27
fn min_horizontal(&self) -> PolarsResult<Option<Column>>;
28
/// Aggregate the column horizontally to their max values.
29
fn max_horizontal(&self) -> PolarsResult<Option<Column>>;
30
}
31
32
impl MinMaxHorizontal for DataFrame {
33
fn min_horizontal(&self) -> PolarsResult<Option<Column>> {
34
min_horizontal(self.get_columns())
35
}
36
fn max_horizontal(&self) -> PolarsResult<Option<Column>> {
37
max_horizontal(self.get_columns())
38
}
39
}
40
41
/// How the horizontal sum / mean aggregations in this module treat null values.
///
/// `Eq` and `Hash` are derived as well so the strategy can be used as a map or set key.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum NullStrategy {
    /// Nulls do not poison the row: `sum_horizontal` fills them with zero before adding.
    Ignore,
    /// Nulls are kept, so they flow through the arithmetic into the result.
    Propagate,
}
46
47
pub trait SumMeanHorizontal {
48
/// Sum all values horizontally across columns.
49
fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>>;
50
51
/// Compute the mean of all numeric values horizontally across columns.
52
fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>>;
53
}
54
55
impl SumMeanHorizontal for DataFrame {
56
fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>> {
57
sum_horizontal(self.get_columns(), null_strategy)
58
}
59
fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>> {
60
mean_horizontal(self.get_columns(), null_strategy)
61
}
62
}
63
64
fn min_binary<T>(left: &ChunkedArray<T>, right: &ChunkedArray<T>) -> ChunkedArray<T>
65
where
66
T: PolarsNumericType,
67
T::Native: PartialOrd,
68
{
69
let op = |l: T::Native, r: T::Native| {
70
if l < r { l } else { r }
71
};
72
arity::binary_elementwise_values(left, right, op)
73
}
74
75
fn max_binary<T>(left: &ChunkedArray<T>, right: &ChunkedArray<T>) -> ChunkedArray<T>
76
where
77
T: PolarsNumericType,
78
T::Native: PartialOrd,
79
{
80
let op = |l: T::Native, r: T::Native| {
81
if l > r { l } else { r }
82
};
83
arity::binary_elementwise_values(left, right, op)
84
}
85
86
fn min_max_binary_columns(left: &Column, right: &Column, min: bool) -> PolarsResult<Column> {
87
if left.dtype().to_physical().is_primitive_numeric()
88
&& right.dtype().to_physical().is_primitive_numeric()
89
&& left.null_count() == 0
90
&& right.null_count() == 0
91
&& left.len() == right.len()
92
{
93
match (left, right) {
94
(Column::Series(left), Column::Series(right)) => {
95
let (lhs, rhs) = coerce_lhs_rhs(left, right)?;
96
let logical = lhs.dtype();
97
let lhs = lhs.to_physical_repr();
98
let rhs = rhs.to_physical_repr();
99
100
with_match_physical_numeric_polars_type!(lhs.dtype(), |$T| {
101
let a: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref();
102
let b: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref();
103
104
unsafe {
105
if min {
106
min_binary(a, b).into_series().from_physical_unchecked(logical)
107
} else {
108
max_binary(a, b).into_series().from_physical_unchecked(logical)
109
}
110
}
111
})
112
.map(Column::from)
113
},
114
_ => {
115
let mask = if min {
116
left.lt(right)?
117
} else {
118
left.gt(right)?
119
};
120
121
left.zip_with(&mask, right)
122
},
123
}
124
} else {
125
let mask = if min {
126
left.lt(right)? & left.is_not_null() | right.is_null()
127
} else {
128
left.gt(right)? & left.is_not_null() | right.is_null()
129
};
130
left.zip_with(&mask, right)
131
}
132
}
133
134
pub fn max_horizontal(columns: &[Column]) -> PolarsResult<Option<Column>> {
135
validate_column_lengths(columns)?;
136
137
let max_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, false);
138
139
match columns.len() {
140
0 => Ok(None),
141
1 => Ok(Some(columns[0].clone())),
142
2 => max_fn(&columns[0], &columns[1]).map(Some),
143
_ => {
144
// the try_reduce_with is a bit slower in parallelism,
145
// but I don't think it matters here as we parallelize over columns, not over elements
146
POOL.install(|| {
147
columns
148
.par_iter()
149
.map(|s| Ok(Cow::Borrowed(s)))
150
.try_reduce_with(|l, r| max_fn(&l, &r).map(Cow::Owned))
151
// we can unwrap the option, because we are certain there is a column
152
// we started this operation on 3 columns
153
.unwrap()
154
.map(|cow| Some(cow.into_owned()))
155
})
156
},
157
}
158
}
159
160
pub fn min_horizontal(columns: &[Column]) -> PolarsResult<Option<Column>> {
161
validate_column_lengths(columns)?;
162
163
let min_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, true);
164
165
match columns.len() {
166
0 => Ok(None),
167
1 => Ok(Some(columns[0].clone())),
168
2 => min_fn(&columns[0], &columns[1]).map(Some),
169
_ => {
170
// the try_reduce_with is a bit slower in parallelism,
171
// but I don't think it matters here as we parallelize over columns, not over elements
172
POOL.install(|| {
173
columns
174
.par_iter()
175
.map(|s| Ok(Cow::Borrowed(s)))
176
.try_reduce_with(|l, r| min_fn(&l, &r).map(Cow::Owned))
177
// we can unwrap the option, because we are certain there is a column
178
// we started this operation on 3 columns
179
.unwrap()
180
.map(|cow| Some(cow.into_owned()))
181
})
182
},
183
}
184
}
185
186
pub fn sum_horizontal(
187
columns: &[Column],
188
null_strategy: NullStrategy,
189
) -> PolarsResult<Option<Column>> {
190
validate_column_lengths(columns)?;
191
let ignore_nulls = null_strategy == NullStrategy::Ignore;
192
193
let apply_null_strategy = |s: Series| -> PolarsResult<Series> {
194
if ignore_nulls && s.null_count() > 0 {
195
s.fill_null(FillNullStrategy::Zero)
196
} else {
197
Ok(s)
198
}
199
};
200
201
let sum_fn = |acc: Series, s: Series| -> PolarsResult<Series> {
202
let acc: Series = apply_null_strategy(acc)?;
203
let s = apply_null_strategy(s)?;
204
// This will do owned arithmetic and can be mutable
205
std::ops::Add::add(acc, s)
206
};
207
208
// @scalar-opt
209
let non_null_cols = columns
210
.iter()
211
.filter(|x| x.dtype() != &DataType::Null)
212
.map(|c| c.as_materialized_series())
213
.collect::<Vec<_>>();
214
215
// If we have any null columns and null strategy is not `Ignore`, we can return immediately.
216
if !ignore_nulls && non_null_cols.len() < columns.len() {
217
// We must determine the correct return dtype.
218
let return_dtype = match dtypes_to_supertype(non_null_cols.iter().map(|c| c.dtype()))? {
219
DataType::Boolean => IDX_DTYPE,
220
dt => dt,
221
};
222
return Ok(Some(Column::full_null(
223
columns[0].name().clone(),
224
columns[0].len(),
225
&return_dtype,
226
)));
227
}
228
229
match non_null_cols.len() {
230
0 => {
231
if columns.is_empty() {
232
Ok(None)
233
} else {
234
// all columns are null dtype, so result is null dtype
235
Ok(Some(columns[0].clone()))
236
}
237
},
238
1 => Ok(Some(
239
apply_null_strategy(if non_null_cols[0].dtype() == &DataType::Boolean {
240
non_null_cols[0].cast(&IDX_DTYPE)?
241
} else {
242
non_null_cols[0].clone()
243
})?
244
.into(),
245
)),
246
2 => sum_fn(non_null_cols[0].clone(), non_null_cols[1].clone())
247
.map(Column::from)
248
.map(Some),
249
_ => {
250
// the try_reduce_with is a bit slower in parallelism,
251
// but I don't think it matters here as we parallelize over columns, not over elements
252
let out = POOL.install(|| {
253
non_null_cols
254
.into_par_iter()
255
.cloned()
256
.map(Ok)
257
.try_reduce_with(sum_fn)
258
// We can unwrap because we started with at least 3 columns, so we always get a Some
259
.unwrap()
260
});
261
out.map(Column::from).map(Some)
262
},
263
}
264
}
265
266
pub fn mean_horizontal(
267
columns: &[Column],
268
null_strategy: NullStrategy,
269
) -> PolarsResult<Option<Column>> {
270
validate_column_lengths(columns)?;
271
272
let (numeric_columns, non_numeric_columns): (Vec<_>, Vec<_>) = columns.iter().partition(|s| {
273
let dtype = s.dtype();
274
dtype.is_primitive_numeric() || dtype.is_decimal() || dtype.is_bool() || dtype.is_null()
275
});
276
277
if !non_numeric_columns.is_empty() {
278
let col = non_numeric_columns.first().cloned();
279
polars_bail!(
280
InvalidOperation: "'horizontal_mean' expects numeric expressions, found {:?} (dtype={})",
281
col.unwrap().name(),
282
col.unwrap().dtype(),
283
);
284
}
285
let columns = numeric_columns.into_iter().cloned().collect::<Vec<_>>();
286
let num_rows = columns.len();
287
match num_rows {
288
0 => Ok(None),
289
1 => Ok(Some(match columns[0].dtype() {
290
dt if dt != &DataType::Float32 && !dt.is_decimal() => {
291
columns[0].cast(&DataType::Float64)?
292
},
293
_ => columns[0].clone(),
294
})),
295
_ => {
296
let sum = || sum_horizontal(columns.as_slice(), null_strategy);
297
let null_count = || {
298
columns
299
.par_iter()
300
.map(|c| {
301
c.is_null()
302
.into_column()
303
.cast_with_options(&DataType::UInt32, CastOptions::NonStrict)
304
})
305
.reduce_with(|l, r| {
306
let l = l?;
307
let r = r?;
308
let result = std::ops::Add::add(&l, &r)?;
309
PolarsResult::Ok(result)
310
})
311
// we can unwrap the option, because we are certain there is a column
312
// we started this operation on 2 columns
313
.unwrap()
314
};
315
316
let (sum, null_count) = POOL.install(|| rayon::join(sum, null_count));
317
let sum = sum?;
318
let null_count = null_count?;
319
320
// value lengths: len - null_count
321
let value_length: UInt32Chunked = (Column::new_scalar(
322
PlSmallStr::EMPTY,
323
Scalar::from(num_rows as u32),
324
null_count.len(),
325
) - null_count)?
326
.u32()
327
.unwrap()
328
.clone();
329
330
// make sure that we do not divide by zero
331
// by replacing with None
332
let dt = if sum
333
.as_ref()
334
.is_some_and(|s| s.dtype() == &DataType::Float32)
335
{
336
&DataType::Float32
337
} else {
338
&DataType::Float64
339
};
340
let value_length = value_length
341
.set(&value_length.equal(0), None)?
342
.into_column()
343
.cast(dt)?;
344
345
sum.map(|sum| std::ops::Div::div(&sum, &value_length))
346
.transpose()
347
},
348
}
349
}
350
351
pub fn coalesce_columns(s: &[Column]) -> PolarsResult<Column> {
352
// TODO! this can be faster if we have more than two inputs.
353
polars_ensure!(!s.is_empty(), NoData: "cannot coalesce empty list");
354
let mut out = s[0].clone();
355
for s in s {
356
if !out.null_count() == 0 {
357
return Ok(out);
358
} else {
359
let mask = out.is_not_null();
360
out = out
361
.as_materialized_series()
362
.zip_with_same_type(&mask, s.as_materialized_series())?
363
.into();
364
}
365
}
366
Ok(out)
367
}
368
369
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_horizontal_agg() {
        // Three Int32 columns with nulls scattered in the second and third.
        let df = DataFrame::new(vec![
            Column::new("a".into(), [1, 2, 6]),
            Column::new("b".into(), [Some(1), None, None]),
            Column::new("c".into(), [Some(4), None, Some(3)]),
        ])
        .unwrap();

        let mean = df.mean_horizontal(NullStrategy::Ignore).unwrap().unwrap();
        assert_eq!(
            Vec::from(mean.f64().unwrap()),
            &[Some(2.0), Some(2.0), Some(4.5)]
        );

        let sum = df.sum_horizontal(NullStrategy::Ignore).unwrap().unwrap();
        assert_eq!(Vec::from(sum.i32().unwrap()), &[Some(6), Some(2), Some(9)]);

        let min = df.min_horizontal().unwrap().unwrap();
        assert_eq!(Vec::from(min.i32().unwrap()), &[Some(1), Some(2), Some(3)]);

        let max = df.max_horizontal().unwrap().unwrap();
        assert_eq!(Vec::from(max.i32().unwrap()), &[Some(4), Some(2), Some(6)]);
    }
}
411
412