Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/aexpr/schema.rs
8424 views
1
#[cfg(feature = "dtype-decimal")]
2
use polars_compute::decimal::DEC128_MAX_PREC;
3
use polars_core::series::arithmetic::NumericListOp;
4
use polars_utils::format_pl_smallstr;
5
use recursive::recursive;
6
7
use super::*;
8
use crate::constants::{
9
POLARS_ELEMENT, POLARS_STRUCTFIELDS, get_literal_name, get_pl_element_name,
10
get_pl_structfields_name,
11
};
12
13
/// Resolve the expression stored at `node` purely for validation.
///
/// The resolved field is thrown away; only a resolution error (if any) is propagated.
fn validate_expr(node: Node, ctx: &ToFieldContext) -> PolarsResult<()> {
    let _resolved = ctx.arena.get(node).to_field_impl(ctx)?;
    Ok(())
}
16
17
#[derive(Debug)]
18
pub struct ToFieldContext<'a> {
19
arena: &'a Arena<AExpr>,
20
schema: &'a Schema,
21
}
22
23
impl<'a> ToFieldContext<'a> {
24
pub fn new(arena: &'a Arena<AExpr>, schema: &'a Schema) -> Self {
25
Self { arena, schema }
26
}
27
}
28
29
impl AExpr {
30
/// Resolve only the output dtype of this expression; the output name is discarded.
pub fn to_dtype(&self, ctx: &ToFieldContext<'_>) -> PolarsResult<DataType> {
    let resolved = self.to_field(ctx)?;
    Ok(resolved.dtype)
}
33
34
/// Get Field result of the expression. The schema is the input data. The result will
35
/// not be coerced (also known as auto-implode): this is the responsibility of the caller.
36
pub fn to_field(&self, ctx: &ToFieldContext<'_>) -> PolarsResult<Field> {
37
self.to_field_impl(ctx)
38
}
39
40
/// Get Field result of the expression. The schema is the input data.
41
///
42
/// This is taken as `&mut bool` as for some expressions this is determined by the upper node
43
/// (e.g. `alias`, `cast`).
44
#[recursive]
45
pub fn to_field_impl(&self, ctx: &ToFieldContext) -> PolarsResult<Field> {
46
use AExpr::*;
47
use DataType::*;
48
match self {
49
Element => ctx
50
.schema
51
.get_field(POLARS_ELEMENT)
52
.ok_or_else(|| polars_err!(invalid_element_use)),
53
54
Len => Ok(Field::new(PlSmallStr::from_static(LEN), IDX_DTYPE)),
55
#[cfg(feature = "dynamic_group_by")]
56
Rolling { function, .. } => {
57
let e = ctx.arena.get(*function);
58
let mut field = e.to_field_impl(ctx)?;
59
// Implicit implode
60
if !is_scalar_ae(*function, ctx.arena) {
61
field.dtype = field.dtype.implode();
62
}
63
Ok(field)
64
},
65
Over {
66
function,
67
partition_by,
68
order_by,
69
mapping,
70
} => {
71
for node in partition_by {
72
validate_expr(*node, ctx)?;
73
}
74
if let Some((node, _)) = order_by {
75
validate_expr(*node, ctx)?;
76
}
77
78
let e = ctx.arena.get(*function);
79
let mut field = e.to_field_impl(ctx)?;
80
81
if matches!(mapping, WindowMapping::Join) && !is_scalar_ae(*function, ctx.arena) {
82
field.dtype = field.dtype.implode();
83
}
84
85
Ok(field)
86
},
87
Explode { expr, .. } => {
88
let field = ctx.arena.get(*expr).to_field_impl(ctx)?;
89
let field = match field.dtype() {
90
List(inner) => Field::new(field.name().clone(), *inner.clone()),
91
#[cfg(feature = "dtype-array")]
92
Array(inner, ..) => Field::new(field.name().clone(), *inner.clone()),
93
_ => field,
94
};
95
96
Ok(field)
97
},
98
Column(name) => ctx
99
.schema
100
.get_field(name)
101
.ok_or_else(|| PolarsError::ColumnNotFound(name.to_string().into())),
102
#[cfg(feature = "dtype-struct")]
103
StructField(name) => {
104
let struct_field = ctx
105
.schema
106
.get_field(POLARS_STRUCTFIELDS)
107
.ok_or_else(|| polars_err!(invalid_field_use))?;
108
let DataType::Struct(fields) = struct_field.dtype() else {
109
return Err(polars_err!(
110
InvalidOperation: "expected `Struct` dtype for `with_fields` Expr, got `{}`",
111
struct_field.dtype()));
112
};
113
// @NOTE. Linear search performance is not ideal. An alternative approach
114
// would be to map each field to a new column with a temporary name (see streaming engine),
115
// and extend the schema accordingly.
116
for f in fields {
117
if f.name() == name {
118
return Ok(f.clone());
119
}
120
}
121
Err(PolarsError::StructFieldNotFound(name.to_string().into()))
122
},
123
Literal(sv) => Ok(match sv {
124
LiteralValue::Series(s) => s.field().into_owned(),
125
_ => Field::new(sv.output_column_name(), sv.get_datatype()),
126
}),
127
BinaryExpr { left, right, op } => {
128
use DataType::*;
129
130
let field = match op {
131
Operator::Lt
132
| Operator::Gt
133
| Operator::Eq
134
| Operator::NotEq
135
| Operator::LogicalAnd
136
| Operator::LtEq
137
| Operator::GtEq
138
| Operator::NotEqValidity
139
| Operator::EqValidity
140
| Operator::LogicalOr => {
141
let out_field;
142
let out_name = {
143
out_field = ctx.arena.get(*left).to_field_impl(ctx)?;
144
out_field.name()
145
};
146
Field::new(out_name.clone(), Boolean)
147
},
148
Operator::TrueDivide => get_truediv_field(*left, *right, ctx)?,
149
_ => get_arithmetic_field(*left, *right, *op, ctx)?,
150
};
151
152
Ok(field)
153
},
154
Sort { expr, .. } => ctx.arena.get(*expr).to_field_impl(ctx),
155
Gather { expr, idx, .. } => {
156
validate_expr(*idx, ctx)?;
157
ctx.arena.get(*expr).to_field_impl(ctx)
158
},
159
SortBy { expr, .. } => ctx.arena.get(*expr).to_field_impl(ctx),
160
Filter { input, by } => {
161
validate_expr(*by, ctx)?;
162
ctx.arena.get(*input).to_field_impl(ctx)
163
},
164
Agg(agg) => {
165
use IRAggExpr::*;
166
match agg {
167
Max { input: expr, .. }
168
| Min { input: expr, .. }
169
| First(expr)
170
| FirstNonNull(expr)
171
| Last(expr)
172
| LastNonNull(expr) => ctx.arena.get(*expr).to_field_impl(ctx),
173
Item { input: expr, .. } => ctx.arena.get(*expr).to_field_impl(ctx),
174
Sum(expr) => {
175
let mut field = ctx.arena.get(*expr).to_field_impl(ctx)?;
176
let dt = match field.dtype() {
177
String | Binary | BinaryOffset | List(_) => {
178
polars_bail!(
179
InvalidOperation: "`sum` operation not supported for dtype `{}`",
180
field.dtype()
181
)
182
},
183
#[cfg(feature = "dtype-array")]
184
Array(_, _) => {
185
polars_bail!(
186
InvalidOperation: "`sum` operation not supported for dtype `{}`",
187
field.dtype()
188
)
189
},
190
#[cfg(feature = "dtype-struct")]
191
Struct(_) => {
192
polars_bail!(
193
InvalidOperation: "`sum` operation not supported for dtype `{}`",
194
field.dtype()
195
)
196
},
197
Boolean => Some(IDX_DTYPE),
198
UInt8 | Int8 | Int16 | UInt16 => Some(Int64),
199
_ => None,
200
};
201
if let Some(dt) = dt {
202
field.coerce(dt);
203
}
204
Ok(field)
205
},
206
Median(expr) => {
207
let field = [ctx.arena.get(*expr).to_field_impl(ctx)?];
208
let mapper = FieldsMapper::new(&field);
209
mapper.moment_dtype()
210
},
211
Mean(expr) => {
212
let field = [ctx.arena.get(*expr).to_field_impl(ctx)?];
213
let mapper = FieldsMapper::new(&field);
214
mapper.moment_dtype()
215
},
216
Implode(expr) => {
217
let mut field = ctx.arena.get(*expr).to_field_impl(ctx)?;
218
field.coerce(DataType::List(field.dtype().clone().into()));
219
Ok(field)
220
},
221
Std(expr, _) => {
222
let field = [ctx.arena.get(*expr).to_field_impl(ctx)?];
223
let mapper = FieldsMapper::new(&field);
224
mapper.moment_dtype()
225
},
226
Var(expr, _) => {
227
let field = [ctx.arena.get(*expr).to_field_impl(ctx)?];
228
let mapper = FieldsMapper::new(&field);
229
mapper.var_dtype()
230
},
231
NUnique(expr) => {
232
let mut field = ctx.arena.get(*expr).to_field_impl(ctx)?;
233
field.coerce(IDX_DTYPE);
234
Ok(field)
235
},
236
Count { input, .. } => {
237
let mut field = ctx.arena.get(*input).to_field_impl(ctx)?;
238
field.coerce(IDX_DTYPE);
239
Ok(field)
240
},
241
AggGroups(expr) => {
242
let mut field = ctx.arena.get(*expr).to_field_impl(ctx)?;
243
field.coerce(IDX_DTYPE.implode());
244
Ok(field)
245
},
246
Quantile { expr, .. } => {
247
let field = [ctx.arena.get(*expr).to_field_impl(ctx)?];
248
let mapper = FieldsMapper::new(&field);
249
mapper.moment_dtype()
250
},
251
}
252
},
253
Cast { expr, dtype, .. } => {
254
let field = ctx.arena.get(*expr).to_field_impl(ctx)?;
255
Ok(Field::new(field.name().clone(), dtype.clone()))
256
},
257
Ternary { truthy, falsy, .. } => {
258
// During aggregation:
259
// left: col(foo): list<T> nesting: 1
260
// right; col(foo).first(): T nesting: 0
261
// col(foo) + col(foo).first() will have nesting 1 as we still maintain the groups list.
262
let mut truthy = ctx.arena.get(*truthy).to_field_impl(ctx)?;
263
let falsy = ctx.arena.get(*falsy).to_field_impl(ctx)?;
264
265
let st = if let DataType::Null = *truthy.dtype() {
266
falsy.dtype().clone()
267
} else {
268
try_get_supertype(truthy.dtype(), falsy.dtype())?
269
};
270
271
truthy.coerce(st);
272
Ok(truthy)
273
},
274
AnonymousFunction {
275
input,
276
function,
277
fmt_str,
278
..
279
} => {
280
let fields = func_args_to_fields(input, ctx)?;
281
polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", fmt_str);
282
let function = function.clone().materialize()?;
283
let out = function.get_field(ctx.schema, &fields)?;
284
Ok(out)
285
},
286
AnonymousAgg {
287
input,
288
function,
289
fmt_str,
290
..
291
} => {
292
let fields = func_args_to_fields(input, ctx)?;
293
polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", fmt_str);
294
let function = function.clone().materialize()?;
295
let out = function.get_field(ctx.schema, &fields)?;
296
Ok(out)
297
},
298
Eval {
299
expr,
300
evaluation,
301
variant,
302
} => {
303
let field = ctx.arena.get(*expr).to_field_impl(ctx)?;
304
305
let element_dtype = variant.element_dtype(field.dtype())?;
306
let mut evaluation_schema = ctx.schema.clone();
307
evaluation_schema.insert(get_pl_element_name(), element_dtype.clone());
308
let mut output_field = ctx
309
.arena
310
.get(*evaluation)
311
.to_field_impl(&ToFieldContext::new(ctx.arena, &evaluation_schema))?;
312
output_field.dtype = output_field.dtype.materialize_unknown(false)?;
313
let eval_is_scalar = is_scalar_ae(*evaluation, ctx.arena);
314
315
output_field.dtype =
316
variant.output_dtype(field.dtype(), output_field.dtype, eval_is_scalar)?;
317
output_field.name = field.name;
318
319
Ok(output_field)
320
},
321
#[cfg(feature = "dtype-struct")]
322
StructEval { expr, evaluation } => {
323
let struct_field = ctx.arena.get(*expr).to_field_impl(ctx)?;
324
let mut evaluation_schema = ctx.schema.clone();
325
evaluation_schema.insert(get_pl_structfields_name(), struct_field.dtype().clone());
326
327
let eval_fields = func_args_to_fields(
328
evaluation,
329
&ToFieldContext::new(ctx.arena, &evaluation_schema),
330
)?;
331
332
// Merge evaluation fields into the expr Struct
333
if let DataType::Struct(expr_fields) = struct_field.dtype() {
334
let mut fields_map =
335
PlIndexMap::with_capacity(expr_fields.len() + eval_fields.len());
336
for field in expr_fields {
337
fields_map.insert(field.name(), field.dtype());
338
}
339
for field in &eval_fields {
340
fields_map.insert(field.name(), field.dtype());
341
}
342
let dtype = DataType::Struct(
343
fields_map
344
.iter()
345
.map(|(&name, &dtype)| Field::new(name.clone(), dtype.clone()))
346
.collect(),
347
);
348
let mut out = struct_field.clone();
349
out.coerce(dtype);
350
Ok(out)
351
} else {
352
let dt = struct_field.dtype();
353
polars_bail!(op = "with_fields", got = dt, expected = "Struct")
354
}
355
},
356
Function {
357
function,
358
input,
359
options: _,
360
} => {
361
#[cfg(feature = "strings")]
362
{
363
if input.is_empty()
364
&& matches!(
365
&function,
366
IRFunctionExpr::StringExpr(IRStringFunction::Format { .. })
367
)
368
{
369
return Ok(Field::new(get_literal_name(), DataType::String));
370
}
371
}
372
373
let fields = func_args_to_fields(input, ctx)?;
374
polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", function);
375
let out = function.get_field(ctx.schema, &fields)?;
376
377
Ok(out)
378
},
379
Slice {
380
input,
381
offset,
382
length,
383
} => {
384
validate_expr(*offset, ctx)?;
385
validate_expr(*length, ctx)?;
386
387
ctx.arena.get(*input).to_field_impl(ctx)
388
},
389
}
390
}
391
392
pub fn to_name(&self, expr_arena: &Arena<AExpr>) -> PlSmallStr {
393
use AExpr::*;
394
use IRAggExpr::*;
395
match self {
396
Element => PlSmallStr::EMPTY,
397
Len => crate::constants::get_len_name(),
398
#[cfg(feature = "dynamic_group_by")]
399
Rolling {
400
function,
401
index_column: _,
402
period: _,
403
offset: _,
404
closed_window: _,
405
} => expr_arena.get(*function).to_name(expr_arena),
406
Over {
407
function: expr,
408
partition_by: _,
409
order_by: _,
410
mapping: _,
411
}
412
| BinaryExpr { left: expr, .. }
413
| Explode { expr, .. }
414
| Sort { expr, .. }
415
| Gather { expr, .. }
416
| SortBy { expr, .. }
417
| Filter { input: expr, .. }
418
| Cast { expr, .. }
419
| Ternary { truthy: expr, .. }
420
| Eval { expr, .. }
421
| Slice { input: expr, .. }
422
| Agg(Min { input: expr, .. })
423
| Agg(Max { input: expr, .. })
424
| Agg(First(expr))
425
| Agg(FirstNonNull(expr))
426
| Agg(Last(expr))
427
| Agg(LastNonNull(expr))
428
| Agg(Item { input: expr, .. })
429
| Agg(Sum(expr))
430
| Agg(Median(expr))
431
| Agg(Mean(expr))
432
| Agg(Implode(expr))
433
| Agg(Std(expr, _))
434
| Agg(Var(expr, _))
435
| Agg(NUnique(expr))
436
| Agg(Count { input: expr, .. })
437
| Agg(AggGroups(expr))
438
| Agg(Quantile { expr, .. }) => expr_arena.get(*expr).to_name(expr_arena),
439
AnonymousFunction { input, fmt_str, .. } | AnonymousAgg { input, fmt_str, .. } => {
440
if input.is_empty() {
441
fmt_str.as_ref().clone()
442
} else {
443
input[0].output_name().clone()
444
}
445
},
446
#[cfg(feature = "dtype-struct")]
447
StructEval { expr, .. } => expr_arena.get(*expr).to_name(expr_arena),
448
Function {
449
input, function, ..
450
} => match function.output_name().and_then(|v| v.into_inner()) {
451
Some(name) => name,
452
None if input.is_empty() => format_pl_smallstr!("{}", &function),
453
None => input[0].output_name().clone(),
454
},
455
Column(name) => name.clone(),
456
#[cfg(feature = "dtype-struct")]
457
StructField(name) => name.clone(),
458
Literal(lv) => lv.output_column_name().clone(),
459
}
460
}
461
}
462
463
/// Resolve the field of every function argument, in order.
///
/// Each resulting field carries the argument's output name (`ExprIR::output_name`)
/// instead of the name produced by resolving the node. Resolution stops at the first error.
fn func_args_to_fields(input: &[ExprIR], ctx: &ToFieldContext) -> PolarsResult<Vec<Field>> {
    let mut fields = Vec::with_capacity(input.len());
    for arg in input {
        let mut field = ctx.arena.get(arg.node()).to_field_impl(ctx)?;
        field.name = arg.output_name().clone();
        fields.push(field);
    }
    Ok(fields)
}
474
475
#[allow(clippy::too_many_arguments)]
476
fn get_arithmetic_field(
477
left: Node,
478
right: Node,
479
op: Operator,
480
ctx: &ToFieldContext,
481
) -> PolarsResult<Field> {
482
use DataType::*;
483
let left_ae = ctx.arena.get(left);
484
let right_ae = ctx.arena.get(right);
485
486
// don't traverse tree until strictly needed. Can have terrible performance.
487
// # 3210
488
489
// take the left field as a whole.
490
// don't take dtype and name separate as that splits the tree every node
491
// leading to quadratic behavior. # 4736
492
//
493
// further right_type is only determined when needed.
494
let mut left_field = left_ae.to_field_impl(ctx)?;
495
496
let super_type = match op {
497
Operator::Minus => {
498
let right_type = right_ae.to_field_impl(ctx)?.dtype;
499
match (&left_field.dtype, &right_type) {
500
#[cfg(feature = "dtype-struct")]
501
(Struct(_), Struct(_)) => {
502
return Ok(left_field);
503
},
504
// This matches the engine output. TODO: revisit pending resolution of GH issue #23797
505
#[cfg(feature = "dtype-struct")]
506
(Struct(_), r) if r.is_numeric() => {
507
return Ok(left_field);
508
},
509
(Duration(_), Datetime(_, _))
510
| (Datetime(_, _), Duration(_))
511
| (Duration(_), Date)
512
| (Date, Duration(_))
513
| (Duration(_), Time)
514
| (Time, Duration(_)) => try_get_supertype(left_field.dtype(), &right_type)?,
515
(Datetime(tu, _), Date) | (Date, Datetime(tu, _)) => Duration(*tu),
516
// T - T != T if T is a datetime / date
517
(Datetime(tul, _), Datetime(tur, _)) => Duration(get_time_units(tul, tur)),
518
(_, Datetime(_, _)) | (Datetime(_, _), _) => {
519
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
520
},
521
(Date, Date) => Duration(TimeUnit::Microseconds),
522
(_, Date) | (Date, _) => {
523
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
524
},
525
(Duration(tul), Duration(tur)) => Duration(get_time_units(tul, tur)),
526
(_, Duration(_)) | (Duration(_), _) => {
527
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
528
},
529
(Time, Time) => Duration(TimeUnit::Nanoseconds),
530
(_, Time) | (Time, _) => {
531
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
532
},
533
(l @ List(a), r @ List(b))
534
if ![a, b]
535
.into_iter()
536
.all(|x| x.is_supported_list_arithmetic_input()) =>
537
{
538
polars_bail!(
539
InvalidOperation:
540
"cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
541
"sub", l, r,
542
)
543
},
544
(list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => {
545
// TODO: This should not use `try_get_supertype()`! It should instead recursively use the enclosing match block.
546
// Otherwise we will silently permit addition operations between logical types (see above).
547
// This currently doesn't cause any problems because the list arithmetic implementation checks and raises errors
548
// if the leaf types aren't numeric, but it means we don't raise an error until execution and the DSL schema
549
// may be incorrect.
550
list_dtype.cast_leaf(NumericListOp::sub().try_get_leaf_supertype(
551
list_dtype.leaf_dtype(),
552
other_dtype.leaf_dtype(),
553
)?)
554
},
555
#[cfg(feature = "dtype-array")]
556
(list_dtype @ Array(..), other_dtype) | (other_dtype, list_dtype @ Array(..)) => {
557
list_dtype.cast_leaf(try_get_supertype(
558
list_dtype.leaf_dtype(),
559
other_dtype.leaf_dtype(),
560
)?)
561
},
562
#[cfg(feature = "dtype-decimal")]
563
(Decimal(_, scale_left), Decimal(_, scale_right)) => {
564
Decimal(DEC128_MAX_PREC, *scale_left.max(scale_right))
565
},
566
(left, right) => try_get_supertype(left, right)?,
567
}
568
},
569
Operator::Plus => {
570
let right_type = right_ae.to_field_impl(ctx)?.dtype;
571
match (&left_field.dtype, &right_type) {
572
#[cfg(feature = "dtype-struct")]
573
(Struct(_), Struct(_)) => {
574
return Ok(left_field);
575
},
576
// This matches the engine output. TODO: revisit pending resolution of GH issue #23797
577
#[cfg(feature = "dtype-struct")]
578
(Struct(_), r) if r.is_numeric() => {
579
return Ok(left_field);
580
},
581
(Duration(_), Datetime(_, _))
582
| (Datetime(_, _), Duration(_))
583
| (Duration(_), Date)
584
| (Date, Duration(_))
585
| (Duration(_), Time)
586
| (Time, Duration(_)) => try_get_supertype(left_field.dtype(), &right_type)?,
587
(_, Datetime(_, _))
588
| (Datetime(_, _), _)
589
| (_, Date)
590
| (Date, _)
591
| (Time, _)
592
| (_, Time) => {
593
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
594
},
595
(Duration(tul), Duration(tur)) => Duration(get_time_units(tul, tur)),
596
(_, Duration(_)) | (Duration(_), _) => {
597
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
598
},
599
(Boolean, Boolean) => IDX_DTYPE,
600
(l @ List(a), r @ List(b))
601
if ![a, b]
602
.into_iter()
603
.all(|x| x.is_supported_list_arithmetic_input()) =>
604
{
605
polars_bail!(
606
InvalidOperation:
607
"cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
608
"add", l, r,
609
)
610
},
611
(list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => {
612
list_dtype.cast_leaf(NumericListOp::add().try_get_leaf_supertype(
613
list_dtype.leaf_dtype(),
614
other_dtype.leaf_dtype(),
615
)?)
616
},
617
#[cfg(feature = "dtype-array")]
618
(list_dtype @ Array(..), other_dtype) | (other_dtype, list_dtype @ Array(..)) => {
619
list_dtype.cast_leaf(try_get_supertype(
620
list_dtype.leaf_dtype(),
621
other_dtype.leaf_dtype(),
622
)?)
623
},
624
#[cfg(feature = "dtype-decimal")]
625
(Decimal(_, scale_left), Decimal(_, scale_right)) => {
626
Decimal(DEC128_MAX_PREC, *scale_left.max(scale_right))
627
},
628
(left, right) => try_get_supertype(left, right)?,
629
}
630
},
631
_ => {
632
let right_type = right_ae.to_field_impl(ctx)?.dtype;
633
634
match (&left_field.dtype, &right_type) {
635
#[cfg(feature = "dtype-struct")]
636
(Struct(_), Struct(_)) => {
637
return Ok(left_field);
638
},
639
// This matches the engine output. TODO: revisit pending resolution of GH issue #23797
640
#[cfg(feature = "dtype-struct")]
641
(Struct(_), r) if r.is_numeric() => {
642
return Ok(left_field);
643
},
644
(Datetime(_, _), _)
645
| (_, Datetime(_, _))
646
| (Time, _)
647
| (_, Time)
648
| (Date, _)
649
| (_, Date) => {
650
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
651
},
652
(Duration(_), Duration(_)) => {
653
// True divide handled somewhere else
654
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
655
},
656
(l, Duration(_)) if l.is_primitive_numeric() => match op {
657
Operator::Multiply => {
658
left_field.coerce(right_type);
659
return Ok(left_field);
660
},
661
_ => {
662
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
663
},
664
},
665
(Duration(_), r) if r.is_primitive_numeric() => match op {
666
Operator::Multiply => {
667
return Ok(left_field);
668
},
669
_ => {
670
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
671
},
672
},
673
#[cfg(feature = "dtype-decimal")]
674
(Decimal(_, scale_left), Decimal(_, scale_right)) => {
675
let dtype = Decimal(DEC128_MAX_PREC, *scale_left.max(scale_right));
676
left_field.coerce(dtype);
677
return Ok(left_field);
678
},
679
680
(l @ List(a), r @ List(b))
681
if ![a, b]
682
.into_iter()
683
.all(|x| x.is_supported_list_arithmetic_input()) =>
684
{
685
polars_bail!(
686
InvalidOperation:
687
"cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
688
op, l, r,
689
)
690
},
691
// List<->primitive operations can be done directly after casting the to the primitive
692
// supertype for the primitive values on both sides.
693
(list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => {
694
let dtype = list_dtype.cast_leaf(try_get_supertype(
695
list_dtype.leaf_dtype(),
696
other_dtype.leaf_dtype(),
697
)?);
698
left_field.coerce(dtype);
699
return Ok(left_field);
700
},
701
#[cfg(feature = "dtype-array")]
702
(list_dtype @ Array(..), other_dtype) | (other_dtype, list_dtype @ Array(..)) => {
703
let dtype = list_dtype.cast_leaf(try_get_supertype(
704
list_dtype.leaf_dtype(),
705
other_dtype.leaf_dtype(),
706
)?);
707
left_field.coerce(dtype);
708
return Ok(left_field);
709
},
710
_ => {
711
// Avoid needlessly type casting numeric columns during arithmetic
712
// with literals.
713
if (left_field.dtype.is_integer() && right_type.is_integer())
714
|| (left_field.dtype.is_float() && right_type.is_float())
715
{
716
match (left_ae, right_ae) {
717
(AExpr::Literal(_), AExpr::Literal(_)) => {},
718
(AExpr::Literal(_), _) if left_field.dtype.is_unknown() => {
719
// literal will be coerced to match right type
720
left_field.coerce(right_type);
721
return Ok(left_field);
722
},
723
(_, AExpr::Literal(_)) if right_type.is_unknown() => {
724
// literal will be coerced to match right type
725
return Ok(left_field);
726
},
727
_ => {},
728
}
729
}
730
},
731
}
732
733
try_get_supertype(&left_field.dtype, &right_type)?
734
},
735
};
736
737
left_field.coerce(super_type);
738
Ok(left_field)
739
}
740
741
/// Resolve the output field of `left / right` (true division).
///
/// The result keeps the name of the left operand; its dtype is computed by
/// `get_truediv_dtype` from both operand dtypes.
fn get_truediv_field(left: Node, right: Node, ctx: &ToFieldContext) -> PolarsResult<Field> {
    let mut out = ctx.arena.get(left).to_field_impl(ctx)?;
    let divisor = ctx.arena.get(right).to_field_impl(ctx)?;
    let quotient_dtype = get_truediv_dtype(out.dtype(), divisor.dtype())?;
    out.coerce(quotient_dtype);
    Ok(out)
}
748
749
fn get_truediv_dtype(left_dtype: &DataType, right_dtype: &DataType) -> PolarsResult<DataType> {
750
use DataType::*;
751
752
// TODO: Re-investigate this. A lot of "_" is being used on the RHS match because this code
753
// originally (mostly) only looked at the LHS dtype.
754
let out_type = match (left_dtype, right_dtype) {
755
#[cfg(feature = "dtype-struct")]
756
(Struct(a), Struct(b)) => {
757
polars_ensure!(a.len() == b.len() || b.len() == 1,
758
InvalidOperation: "cannot {} two structs of different length (left: {}, right: {})",
759
"div", a.len(), b.len()
760
);
761
let mut fields = Vec::with_capacity(a.len());
762
// In case b.len() == 1, we broadcast the first field (b[0]).
763
// Safety is assured by the constraints above.
764
let b_iter = (0..a.len()).map(|i| b.get(i.min(b.len() - 1)).unwrap());
765
for (left, right) in a.iter().zip(b_iter) {
766
let name = left.name.clone();
767
let (left, right) = (left.dtype(), right.dtype());
768
if !(left.is_numeric() && right.is_numeric()) {
769
polars_bail!(InvalidOperation:
770
"cannot {} two structs with non-numeric fields: (left: {}, right: {})",
771
"div", left, right,)
772
};
773
let field = Field::new(name, get_truediv_dtype(left, right)?);
774
fields.push(field);
775
}
776
Struct(fields)
777
},
778
#[cfg(feature = "dtype-struct")]
779
(Struct(a), n) if n.is_numeric() => {
780
let mut fields = Vec::with_capacity(a.len());
781
for left in a.iter() {
782
let name = left.name.clone();
783
let left = left.dtype();
784
if !(left.is_numeric()) {
785
polars_bail!(InvalidOperation:
786
"cannot {} a struct with non-numeric field: (left: {})",
787
"div", left)
788
};
789
let field = Field::new(name, get_truediv_dtype(left, n)?);
790
fields.push(field);
791
}
792
Struct(fields)
793
},
794
(l @ List(a), r @ List(b))
795
if ![a, b]
796
.into_iter()
797
.all(|x| x.is_supported_list_arithmetic_input()) =>
798
{
799
polars_bail!(
800
InvalidOperation:
801
"cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
802
"div", l, r,
803
)
804
},
805
(list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => {
806
let dtype = get_truediv_dtype(list_dtype.leaf_dtype(), other_dtype.leaf_dtype())?;
807
list_dtype.cast_leaf(dtype)
808
},
809
#[cfg(feature = "dtype-array")]
810
(list_dtype @ Array(..), other_dtype) | (other_dtype, list_dtype @ Array(..)) => {
811
let dtype = get_truediv_dtype(list_dtype.leaf_dtype(), other_dtype.leaf_dtype())?;
812
list_dtype.cast_leaf(dtype)
813
},
814
#[cfg(feature = "dtype-f16")]
815
(Boolean, Float16) => Float16,
816
(Boolean, Float32) => Float32,
817
(Boolean, b) if b.is_numeric() => Float64,
818
(Boolean, Boolean) => Float64,
819
#[cfg(all(feature = "dtype-f16", feature = "dtype-u8"))]
820
(Float16, UInt8 | Int8) => Float16,
821
#[cfg(all(feature = "dtype-f16", feature = "dtype-u16"))]
822
(Float16, UInt16 | Int16) => Float32,
823
#[cfg(feature = "dtype-f16")]
824
(Float16, Unknown(UnknownKind::Int(_))) => Float16,
825
#[cfg(feature = "dtype-f16")]
826
(Float16, other) if other.is_integer() => Float64,
827
#[cfg(feature = "dtype-f16")]
828
(Float16, Float32) => Float32,
829
#[cfg(feature = "dtype-f16")]
830
(Float16, Float64) => Float64,
831
#[cfg(feature = "dtype-f16")]
832
(Float16, _) => Float16,
833
#[cfg(feature = "dtype-u8")]
834
(Float32, UInt8 | Int8) => Float32,
835
#[cfg(feature = "dtype-u16")]
836
(Float32, UInt16 | Int16) => Float32,
837
(Float32, Unknown(UnknownKind::Int(_))) => Float32,
838
(Float32, other) if other.is_integer() => Float64,
839
(Float32, Float64) => Float64,
840
(Float32, _) => Float32,
841
(String, _) | (_, String) => polars_bail!(
842
InvalidOperation: "division with 'String' datatypes is not allowed"
843
),
844
#[cfg(feature = "dtype-decimal")]
845
(Decimal(_, scale_left), Decimal(_, scale_right)) => {
846
Decimal(DEC128_MAX_PREC, *scale_left.max(scale_right))
847
},
848
#[cfg(all(feature = "dtype-u8", feature = "dtype-f16"))]
849
(UInt8 | Int8, Float16) => Float16,
850
#[cfg(all(feature = "dtype-u16", feature = "dtype-f16"))]
851
(UInt16 | Int16, Float16) => Float32,
852
#[cfg(feature = "dtype-u8")]
853
(UInt8 | Int8, Float32) => Float32,
854
#[cfg(feature = "dtype-u16")]
855
(UInt16 | Int16, Float32) => Float32,
856
(dt, _) if dt.is_primitive_numeric() => Float64,
857
#[cfg(feature = "dtype-duration")]
858
(Duration(_), Duration(_)) => Float64,
859
#[cfg(feature = "dtype-duration")]
860
(Duration(_), dt) if dt.is_primitive_numeric() => left_dtype.clone(),
861
#[cfg(feature = "dtype-duration")]
862
(Duration(_), dt) => {
863
polars_bail!(InvalidOperation: "true division of {} with {} is not allowed", left_dtype, dt)
864
},
865
#[cfg(feature = "dtype-datetime")]
866
(Datetime(_, _), _) => {
867
polars_bail!(InvalidOperation: "division of 'Datetime' datatype is not allowed")
868
},
869
#[cfg(feature = "dtype-time")]
870
(Time, _) => polars_bail!(InvalidOperation: "division of 'Time' datatype is not allowed"),
871
#[cfg(feature = "dtype-date")]
872
(Date, _) => polars_bail!(InvalidOperation: "division of 'Date' datatype is not allowed"),
873
// we don't know what to do here, best return the dtype
874
(dt, _) => dt.clone(),
875
};
876
Ok(out_type)
877
}
878
879