Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/aexpr/schema.rs
6940 views
1
#[cfg(feature = "dtype-decimal")]
2
use polars_core::chunked_array::arithmetic::{
3
_get_decimal_scale_add_sub, _get_decimal_scale_div, _get_decimal_scale_mul,
4
};
5
use polars_utils::format_pl_smallstr;
6
use recursive::recursive;
7
8
use super::*;
9
10
/// Check that the expression rooted at `node` can be resolved against `schema`.
///
/// The node is resolved with validation switched on, so auxiliary inputs that do not
/// contribute to the output dtype (e.g. filter predicates or gather indices) are resolved
/// as well. The resulting field is discarded; only success or the error is reported.
fn validate_expr(node: Node, arena: &Arena<AExpr>, schema: &Schema) -> PolarsResult<()> {
    let expr = arena.get(node);
    let mut ctx = ToFieldContext {
        arena,
        schema,
        validate: true,
    };
    expr.to_field_impl(&mut ctx)?;
    Ok(())
}
18
19
struct ToFieldContext<'a> {
20
schema: &'a Schema,
21
arena: &'a Arena<AExpr>,
22
// Traverse all expressions to validate they are in the schema.
23
validate: bool,
24
}
25
26
impl AExpr {
27
pub fn to_dtype(&self, schema: &Schema, arena: &Arena<AExpr>) -> PolarsResult<DataType> {
28
self.to_field(schema, arena).map(|f| f.dtype)
29
}
30
31
/// Get Field result of the expression. The schema is the input data. The provided
32
/// context will be used to coerce the type into a List if needed, also known as auto-implode.
33
pub fn to_field_with_ctx(
34
&self,
35
schema: &Schema,
36
ctx: Context,
37
arena: &Arena<AExpr>,
38
) -> PolarsResult<Field> {
39
// Indicates whether we should auto-implode the result. This is initialized to true if we are
40
// in an aggregation context, so functions that return scalars should explicitly set this
41
// to false in `to_field_impl`.
42
let agg_list = matches!(ctx, Context::Aggregation);
43
let mut ctx = ToFieldContext {
44
schema,
45
arena,
46
validate: true,
47
};
48
let mut field = self.to_field_impl(&mut ctx)?;
49
50
if agg_list {
51
if !self.is_scalar(arena) {
52
field.coerce(field.dtype().clone().implode());
53
}
54
}
55
56
Ok(field)
57
}
58
59
/// Get Field result of the expression. The schema is the input data. The result will
60
/// not be coerced (also known as auto-implode): this is the responsibility of the caller.
61
pub fn to_field(&self, schema: &Schema, arena: &Arena<AExpr>) -> PolarsResult<Field> {
62
let mut ctx = ToFieldContext {
63
schema,
64
arena,
65
validate: true,
66
};
67
68
let field = self.to_field_impl(&mut ctx)?;
69
70
Ok(field)
71
}
72
73
/// Get Field result of the expression. The schema is the input data.
74
///
75
/// This is taken as `&mut bool` as for some expressions this is determined by the upper node
76
/// (e.g. `alias`, `cast`).
77
#[recursive]
78
pub fn to_field_impl(&self, ctx: &mut ToFieldContext) -> PolarsResult<Field> {
79
use AExpr::*;
80
use DataType::*;
81
match self {
82
Len => Ok(Field::new(PlSmallStr::from_static(LEN), IDX_DTYPE)),
83
Window {
84
function,
85
options,
86
partition_by,
87
order_by,
88
} => {
89
if ctx.validate {
90
for node in partition_by {
91
validate_expr(*node, ctx.arena, ctx.schema)?;
92
}
93
if let Some((node, _)) = order_by {
94
validate_expr(*node, ctx.arena, ctx.schema)?;
95
}
96
}
97
98
let e = ctx.arena.get(*function);
99
let mut field = e.to_field_impl(ctx)?;
100
101
let mut implicit_implode = false;
102
103
implicit_implode |= matches!(options, WindowType::Over(WindowMapping::Join));
104
#[cfg(feature = "dynamic_group_by")]
105
{
106
implicit_implode |= matches!(options, WindowType::Rolling(_));
107
}
108
109
if implicit_implode && !is_scalar_ae(*function, ctx.arena) {
110
field.dtype = field.dtype.implode();
111
}
112
113
Ok(field)
114
},
115
Explode { expr, .. } => {
116
let field = ctx.arena.get(*expr).to_field_impl(ctx)?;
117
let field = match field.dtype() {
118
List(inner) => Field::new(field.name().clone(), *inner.clone()),
119
#[cfg(feature = "dtype-array")]
120
Array(inner, ..) => Field::new(field.name().clone(), *inner.clone()),
121
_ => field,
122
};
123
124
Ok(field)
125
},
126
Column(name) => ctx
127
.schema
128
.get_field(name)
129
.ok_or_else(|| PolarsError::ColumnNotFound(name.to_string().into())),
130
Literal(sv) => Ok(match sv {
131
LiteralValue::Series(s) => s.field().into_owned(),
132
_ => Field::new(sv.output_column_name().clone(), sv.get_datatype()),
133
}),
134
BinaryExpr { left, right, op } => {
135
use DataType::*;
136
137
let field = match op {
138
Operator::Lt
139
| Operator::Gt
140
| Operator::Eq
141
| Operator::NotEq
142
| Operator::LogicalAnd
143
| Operator::LtEq
144
| Operator::GtEq
145
| Operator::NotEqValidity
146
| Operator::EqValidity
147
| Operator::LogicalOr => {
148
let out_field;
149
let out_name = {
150
out_field = ctx.arena.get(*left).to_field_impl(ctx)?;
151
out_field.name()
152
};
153
Field::new(out_name.clone(), Boolean)
154
},
155
Operator::TrueDivide => get_truediv_field(*left, *right, ctx)?,
156
_ => get_arithmetic_field(*left, *right, *op, ctx)?,
157
};
158
159
Ok(field)
160
},
161
Sort { expr, .. } => ctx.arena.get(*expr).to_field_impl(ctx),
162
Gather { expr, idx, .. } => {
163
if ctx.validate {
164
validate_expr(*idx, ctx.arena, ctx.schema)?
165
}
166
ctx.arena.get(*expr).to_field_impl(ctx)
167
},
168
SortBy { expr, .. } => ctx.arena.get(*expr).to_field_impl(ctx),
169
Filter { input, by } => {
170
if ctx.validate {
171
validate_expr(*by, ctx.arena, ctx.schema)?
172
}
173
ctx.arena.get(*input).to_field_impl(ctx)
174
},
175
Agg(agg) => {
176
use IRAggExpr::*;
177
match agg {
178
Max { input: expr, .. }
179
| Min { input: expr, .. }
180
| First(expr)
181
| Last(expr) => ctx.arena.get(*expr).to_field_impl(ctx),
182
Sum(expr) => {
183
let mut field = ctx.arena.get(*expr).to_field_impl(ctx)?;
184
let dt = match field.dtype() {
185
Boolean => Some(IDX_DTYPE),
186
UInt8 | Int8 | Int16 | UInt16 => Some(Int64),
187
_ => None,
188
};
189
if let Some(dt) = dt {
190
field.coerce(dt);
191
}
192
Ok(field)
193
},
194
Median(expr) => {
195
let mut field = ctx.arena.get(*expr).to_field_impl(ctx)?;
196
match field.dtype {
197
Date => field.coerce(Datetime(TimeUnit::Microseconds, None)),
198
_ => {
199
let field = [ctx.arena.get(*expr).to_field_impl(ctx)?];
200
let mapper = FieldsMapper::new(&field);
201
return mapper.moment_dtype();
202
},
203
}
204
Ok(field)
205
},
206
Mean(expr) => {
207
let mut field = ctx.arena.get(*expr).to_field_impl(ctx)?;
208
match field.dtype {
209
Date => field.coerce(Datetime(TimeUnit::Microseconds, None)),
210
_ => {
211
let field = [ctx.arena.get(*expr).to_field_impl(ctx)?];
212
let mapper = FieldsMapper::new(&field);
213
return mapper.moment_dtype();
214
},
215
}
216
Ok(field)
217
},
218
Implode(expr) => {
219
let mut field = ctx.arena.get(*expr).to_field_impl(ctx)?;
220
field.coerce(DataType::List(field.dtype().clone().into()));
221
Ok(field)
222
},
223
Std(expr, _) => {
224
let field = [ctx.arena.get(*expr).to_field_impl(ctx)?];
225
let mapper = FieldsMapper::new(&field);
226
mapper.moment_dtype()
227
},
228
Var(expr, _) => {
229
let field = [ctx.arena.get(*expr).to_field_impl(ctx)?];
230
let mapper = FieldsMapper::new(&field);
231
mapper.var_dtype()
232
},
233
NUnique(expr) => {
234
let mut field = ctx.arena.get(*expr).to_field_impl(ctx)?;
235
field.coerce(IDX_DTYPE);
236
Ok(field)
237
},
238
Count { input, .. } => {
239
let mut field = ctx.arena.get(*input).to_field_impl(ctx)?;
240
field.coerce(IDX_DTYPE);
241
Ok(field)
242
},
243
AggGroups(expr) => {
244
let mut field = ctx.arena.get(*expr).to_field_impl(ctx)?;
245
field.coerce(IDX_DTYPE.implode());
246
Ok(field)
247
},
248
Quantile { expr, .. } => {
249
let field = [ctx.arena.get(*expr).to_field_impl(ctx)?];
250
let mapper = FieldsMapper::new(&field);
251
mapper.map_numeric_to_float_dtype(true)
252
},
253
}
254
},
255
Cast { expr, dtype, .. } => {
256
let field = ctx.arena.get(*expr).to_field_impl(ctx)?;
257
Ok(Field::new(field.name().clone(), dtype.clone()))
258
},
259
Ternary { truthy, falsy, .. } => {
260
// During aggregation:
261
// left: col(foo): list<T> nesting: 1
262
// right; col(foo).first(): T nesting: 0
263
// col(foo) + col(foo).first() will have nesting 1 as we still maintain the groups list.
264
let mut truthy = ctx.arena.get(*truthy).to_field_impl(ctx)?;
265
let falsy = ctx.arena.get(*falsy).to_field_impl(ctx)?;
266
267
let st = if let DataType::Null = *truthy.dtype() {
268
falsy.dtype().clone()
269
} else {
270
try_get_supertype(truthy.dtype(), falsy.dtype())?
271
};
272
273
truthy.coerce(st);
274
Ok(truthy)
275
},
276
AnonymousFunction {
277
input,
278
function,
279
fmt_str,
280
..
281
} => {
282
let fields = func_args_to_fields(input, ctx)?;
283
polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", fmt_str);
284
let function = function.clone().materialize()?;
285
let out = function.get_field(ctx.schema, &fields)?;
286
Ok(out)
287
},
288
Eval {
289
expr,
290
evaluation,
291
variant,
292
} => {
293
let field = ctx.arena.get(*expr).to_field_impl(ctx)?;
294
295
let element_dtype = variant.element_dtype(field.dtype())?;
296
let schema = Schema::from_iter([(PlSmallStr::EMPTY, element_dtype.clone())]);
297
298
let mut ctx = ToFieldContext {
299
schema: &schema,
300
arena: ctx.arena,
301
validate: ctx.validate,
302
};
303
let mut output_field = ctx.arena.get(*evaluation).to_field_impl(&mut ctx)?;
304
output_field.dtype = output_field.dtype.materialize_unknown(false)?;
305
306
output_field.dtype = match variant {
307
EvalVariant::List => DataType::List(Box::new(output_field.dtype)),
308
EvalVariant::Cumulative { .. } => output_field.dtype,
309
};
310
output_field.name = field.name;
311
312
Ok(output_field)
313
},
314
Function {
315
function,
316
input,
317
options: _,
318
} => {
319
let fields = func_args_to_fields(input, ctx)?;
320
polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", function);
321
let out = function.get_field(ctx.schema, &fields)?;
322
323
Ok(out)
324
},
325
Slice {
326
input,
327
offset,
328
length,
329
} => {
330
if ctx.validate {
331
validate_expr(*offset, ctx.arena, ctx.schema)?;
332
validate_expr(*length, ctx.arena, ctx.schema)?;
333
}
334
335
ctx.arena.get(*input).to_field_impl(ctx)
336
},
337
}
338
}
339
340
pub fn to_name(&self, expr_arena: &Arena<AExpr>) -> PlSmallStr {
341
use AExpr::*;
342
use IRAggExpr::*;
343
match self {
344
Len => crate::constants::get_len_name(),
345
Window {
346
function: expr,
347
options: _,
348
partition_by: _,
349
order_by: _,
350
}
351
| BinaryExpr { left: expr, .. }
352
| Explode { expr, .. }
353
| Sort { expr, .. }
354
| Gather { expr, .. }
355
| SortBy { expr, .. }
356
| Filter { input: expr, .. }
357
| Cast { expr, .. }
358
| Ternary { truthy: expr, .. }
359
| Eval { expr, .. }
360
| Slice { input: expr, .. }
361
| Agg(Max { input: expr, .. })
362
| Agg(Min { input: expr, .. })
363
| Agg(First(expr))
364
| Agg(Last(expr))
365
| Agg(Sum(expr))
366
| Agg(Median(expr))
367
| Agg(Mean(expr))
368
| Agg(Implode(expr))
369
| Agg(Std(expr, _))
370
| Agg(Var(expr, _))
371
| Agg(NUnique(expr))
372
| Agg(Count { input: expr, .. })
373
| Agg(AggGroups(expr))
374
| Agg(Quantile { expr, .. }) => expr_arena.get(*expr).to_name(expr_arena),
375
AnonymousFunction { input, fmt_str, .. } => {
376
if input.is_empty() {
377
fmt_str.as_ref().clone()
378
} else {
379
input[0].output_name().clone()
380
}
381
},
382
Function {
383
input, function, ..
384
} => match function.output_name().and_then(|v| v.into_inner()) {
385
Some(name) => name,
386
None if input.is_empty() => format_pl_smallstr!("{}", &function),
387
None => input[0].output_name().clone(),
388
},
389
Column(name) => name.clone(),
390
Literal(lv) => lv.output_column_name().clone(),
391
}
392
}
393
}
394
395
/// Resolve the field of every function argument in `input`.
///
/// Each resolved field is renamed to the argument's `output_name()`. Resolution stops at,
/// and returns, the first error encountered.
fn func_args_to_fields(input: &[ExprIR], ctx: &mut ToFieldContext) -> PolarsResult<Vec<Field>> {
    let mut fields = Vec::with_capacity(input.len());
    for arg in input {
        let mut resolved = ctx.arena.get(arg.node()).to_field_impl(ctx)?;
        resolved.name = arg.output_name().clone();
        fields.push(resolved);
    }
    Ok(fields)
}
406
407
#[allow(clippy::too_many_arguments)]
408
fn get_arithmetic_field(
409
left: Node,
410
right: Node,
411
op: Operator,
412
ctx: &mut ToFieldContext,
413
) -> PolarsResult<Field> {
414
use DataType::*;
415
let left_ae = ctx.arena.get(left);
416
let right_ae = ctx.arena.get(right);
417
418
// don't traverse tree until strictly needed. Can have terrible performance.
419
// # 3210
420
421
// take the left field as a whole.
422
// don't take dtype and name separate as that splits the tree every node
423
// leading to quadratic behavior. # 4736
424
//
425
// further right_type is only determined when needed.
426
let mut left_field = left_ae.to_field_impl(ctx)?;
427
428
let super_type = match op {
429
Operator::Minus => {
430
let right_type = right_ae.to_field_impl(ctx)?.dtype;
431
match (&left_field.dtype, &right_type) {
432
#[cfg(feature = "dtype-struct")]
433
(Struct(_), Struct(_)) => {
434
return Ok(left_field);
435
},
436
// This matches the engine output. TODO: revisit pending resolution of GH issue #23797
437
#[cfg(feature = "dtype-struct")]
438
(Struct(_), r) if r.is_numeric() => {
439
return Ok(left_field);
440
},
441
(Duration(_), Datetime(_, _))
442
| (Datetime(_, _), Duration(_))
443
| (Duration(_), Date)
444
| (Date, Duration(_))
445
| (Duration(_), Time)
446
| (Time, Duration(_)) => try_get_supertype(left_field.dtype(), &right_type)?,
447
(Datetime(tu, _), Date) | (Date, Datetime(tu, _)) => Duration(*tu),
448
// T - T != T if T is a datetime / date
449
(Datetime(tul, _), Datetime(tur, _)) => Duration(get_time_units(tul, tur)),
450
(_, Datetime(_, _)) | (Datetime(_, _), _) => {
451
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
452
},
453
(Date, Date) => Duration(TimeUnit::Microseconds),
454
(_, Date) | (Date, _) => {
455
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
456
},
457
(Duration(tul), Duration(tur)) => Duration(get_time_units(tul, tur)),
458
(_, Duration(_)) | (Duration(_), _) => {
459
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
460
},
461
(Time, Time) => Duration(TimeUnit::Nanoseconds),
462
(_, Time) | (Time, _) => {
463
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
464
},
465
(l @ List(a), r @ List(b))
466
if ![a, b]
467
.into_iter()
468
.all(|x| x.is_supported_list_arithmetic_input()) =>
469
{
470
polars_bail!(
471
InvalidOperation:
472
"cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
473
"sub", l, r,
474
)
475
},
476
(list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => {
477
// FIXME: This should not use `try_get_supertype()`! It should instead recursively use the enclosing match block.
478
// Otherwise we will silently permit addition operations between logical types (see above).
479
// This currently doesn't cause any problems because the list arithmetic implementation checks and raises errors
480
// if the leaf types aren't numeric, but it means we don't raise an error until execution and the DSL schema
481
// may be incorrect.
482
list_dtype.cast_leaf(try_get_supertype(
483
list_dtype.leaf_dtype(),
484
other_dtype.leaf_dtype(),
485
)?)
486
},
487
#[cfg(feature = "dtype-array")]
488
(list_dtype @ Array(..), other_dtype) | (other_dtype, list_dtype @ Array(..)) => {
489
list_dtype.cast_leaf(try_get_supertype(
490
list_dtype.leaf_dtype(),
491
other_dtype.leaf_dtype(),
492
)?)
493
},
494
#[cfg(feature = "dtype-decimal")]
495
(Decimal(_, Some(scale_left)), Decimal(_, Some(scale_right))) => {
496
let scale = _get_decimal_scale_add_sub(*scale_left, *scale_right);
497
Decimal(None, Some(scale))
498
},
499
(left, right) => try_get_supertype(left, right)?,
500
}
501
},
502
Operator::Plus => {
503
let right_type = right_ae.to_field_impl(ctx)?.dtype;
504
match (&left_field.dtype, &right_type) {
505
#[cfg(feature = "dtype-struct")]
506
(Struct(_), Struct(_)) => {
507
return Ok(left_field);
508
},
509
// This matches the engine output. TODO: revisit pending resolution of GH issue #23797
510
#[cfg(feature = "dtype-struct")]
511
(Struct(_), r) if r.is_numeric() => {
512
return Ok(left_field);
513
},
514
(Duration(_), Datetime(_, _))
515
| (Datetime(_, _), Duration(_))
516
| (Duration(_), Date)
517
| (Date, Duration(_))
518
| (Duration(_), Time)
519
| (Time, Duration(_)) => try_get_supertype(left_field.dtype(), &right_type)?,
520
(_, Datetime(_, _))
521
| (Datetime(_, _), _)
522
| (_, Date)
523
| (Date, _)
524
| (Time, _)
525
| (_, Time) => {
526
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
527
},
528
(Duration(tul), Duration(tur)) => Duration(get_time_units(tul, tur)),
529
(_, Duration(_)) | (Duration(_), _) => {
530
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
531
},
532
(Boolean, Boolean) => IDX_DTYPE,
533
(l @ List(a), r @ List(b))
534
if ![a, b]
535
.into_iter()
536
.all(|x| x.is_supported_list_arithmetic_input()) =>
537
{
538
polars_bail!(
539
InvalidOperation:
540
"cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
541
"add", l, r,
542
)
543
},
544
(list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => {
545
list_dtype.cast_leaf(try_get_supertype(
546
list_dtype.leaf_dtype(),
547
other_dtype.leaf_dtype(),
548
)?)
549
},
550
#[cfg(feature = "dtype-array")]
551
(list_dtype @ Array(..), other_dtype) | (other_dtype, list_dtype @ Array(..)) => {
552
list_dtype.cast_leaf(try_get_supertype(
553
list_dtype.leaf_dtype(),
554
other_dtype.leaf_dtype(),
555
)?)
556
},
557
#[cfg(feature = "dtype-decimal")]
558
(Decimal(_, Some(scale_left)), Decimal(_, Some(scale_right))) => {
559
let scale = _get_decimal_scale_add_sub(*scale_left, *scale_right);
560
Decimal(None, Some(scale))
561
},
562
(left, right) => try_get_supertype(left, right)?,
563
}
564
},
565
_ => {
566
let right_type = right_ae.to_field_impl(ctx)?.dtype;
567
568
match (&left_field.dtype, &right_type) {
569
#[cfg(feature = "dtype-struct")]
570
(Struct(_), Struct(_)) => {
571
return Ok(left_field);
572
},
573
// This matches the engine output. TODO: revisit pending resolution of GH issue #23797
574
#[cfg(feature = "dtype-struct")]
575
(Struct(_), r) if r.is_numeric() => {
576
return Ok(left_field);
577
},
578
(Datetime(_, _), _)
579
| (_, Datetime(_, _))
580
| (Time, _)
581
| (_, Time)
582
| (Date, _)
583
| (_, Date) => {
584
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
585
},
586
(Duration(_), Duration(_)) => {
587
// True divide handled somewhere else
588
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
589
},
590
(l, Duration(_)) if l.is_primitive_numeric() => match op {
591
Operator::Multiply => {
592
left_field.coerce(right_type);
593
return Ok(left_field);
594
},
595
_ => {
596
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
597
},
598
},
599
(Duration(_), r) if r.is_primitive_numeric() => match op {
600
Operator::Multiply => {
601
return Ok(left_field);
602
},
603
_ => {
604
polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
605
},
606
},
607
#[cfg(feature = "dtype-decimal")]
608
(Decimal(_, Some(scale_left)), Decimal(_, Some(scale_right))) => {
609
let scale = match op {
610
Operator::Multiply => _get_decimal_scale_mul(*scale_left, *scale_right),
611
Operator::Divide | Operator::TrueDivide => {
612
_get_decimal_scale_div(*scale_left)
613
},
614
_ => {
615
debug_assert!(false);
616
*scale_left
617
},
618
};
619
let dtype = Decimal(None, Some(scale));
620
left_field.coerce(dtype);
621
return Ok(left_field);
622
},
623
624
(l @ List(a), r @ List(b))
625
if ![a, b]
626
.into_iter()
627
.all(|x| x.is_supported_list_arithmetic_input()) =>
628
{
629
polars_bail!(
630
InvalidOperation:
631
"cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
632
op, l, r,
633
)
634
},
635
// List<->primitive operations can be done directly after casting the to the primitive
636
// supertype for the primitive values on both sides.
637
(list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => {
638
let dtype = list_dtype.cast_leaf(try_get_supertype(
639
list_dtype.leaf_dtype(),
640
other_dtype.leaf_dtype(),
641
)?);
642
left_field.coerce(dtype);
643
return Ok(left_field);
644
},
645
#[cfg(feature = "dtype-array")]
646
(list_dtype @ Array(..), other_dtype) | (other_dtype, list_dtype @ Array(..)) => {
647
let dtype = list_dtype.cast_leaf(try_get_supertype(
648
list_dtype.leaf_dtype(),
649
other_dtype.leaf_dtype(),
650
)?);
651
left_field.coerce(dtype);
652
return Ok(left_field);
653
},
654
_ => {
655
// Avoid needlessly type casting numeric columns during arithmetic
656
// with literals.
657
if (left_field.dtype.is_integer() && right_type.is_integer())
658
|| (left_field.dtype.is_float() && right_type.is_float())
659
{
660
match (left_ae, right_ae) {
661
(AExpr::Literal(_), AExpr::Literal(_)) => {},
662
(AExpr::Literal(_), _) => {
663
// literal will be coerced to match right type
664
left_field.coerce(right_type);
665
return Ok(left_field);
666
},
667
(_, AExpr::Literal(_)) => {
668
// literal will be coerced to match right type
669
return Ok(left_field);
670
},
671
_ => {},
672
}
673
}
674
},
675
}
676
677
try_get_supertype(&left_field.dtype, &right_type)?
678
},
679
};
680
681
left_field.coerce(super_type);
682
Ok(left_field)
683
}
684
685
/// Resolve the output field of the true division `left / right`.
///
/// The left operand is resolved first, then the right one. The result keeps the left
/// operand's name, with the dtype chosen by `get_truediv_dtype` for the two operand dtypes.
fn get_truediv_field(left: Node, right: Node, ctx: &mut ToFieldContext) -> PolarsResult<Field> {
    let lhs = ctx.arena.get(left).to_field_impl(ctx)?;
    let rhs = ctx.arena.get(right).to_field_impl(ctx)?;
    let dtype = get_truediv_dtype(lhs.dtype(), rhs.dtype())?;
    let mut out = lhs;
    out.coerce(dtype);
    Ok(out)
}
692
693
fn get_truediv_dtype(left_dtype: &DataType, right_dtype: &DataType) -> PolarsResult<DataType> {
694
use DataType::*;
695
696
// TODO: Re-investigate this. A lot of "_" is being used on the RHS match because this code
697
// originally (mostly) only looked at the LHS dtype.
698
let out_type = match (left_dtype, right_dtype) {
699
#[cfg(feature = "dtype-struct")]
700
(Struct(a), Struct(b)) => {
701
polars_ensure!(a.len() == b.len() || b.len() == 1,
702
InvalidOperation: "cannot {} two structs of different length (left: {}, right: {})",
703
"div", a.len(), b.len()
704
);
705
let mut fields = Vec::with_capacity(a.len());
706
// In case b.len() == 1, we broadcast the first field (b[0]).
707
// Safety is assured by the constraints above.
708
let b_iter = (0..a.len()).map(|i| b.get(i.min(b.len() - 1)).unwrap());
709
for (left, right) in a.iter().zip(b_iter) {
710
let name = left.name.clone();
711
let (left, right) = (left.dtype(), right.dtype());
712
if !(left.is_numeric() && right.is_numeric()) {
713
polars_bail!(InvalidOperation:
714
"cannot {} two structs with non-numeric fields: (left: {}, right: {})",
715
"div", left, right,)
716
};
717
let field = Field::new(name, get_truediv_dtype(left, right)?);
718
fields.push(field);
719
}
720
Struct(fields)
721
},
722
#[cfg(feature = "dtype-struct")]
723
(Struct(a), n) if n.is_numeric() => {
724
let mut fields = Vec::with_capacity(a.len());
725
for left in a.iter() {
726
let name = left.name.clone();
727
let left = left.dtype();
728
if !(left.is_numeric()) {
729
polars_bail!(InvalidOperation:
730
"cannot {} a struct with non-numeric field: (left: {})",
731
"div", left)
732
};
733
let field = Field::new(name, get_truediv_dtype(left, n)?);
734
fields.push(field);
735
}
736
Struct(fields)
737
},
738
(l @ List(a), r @ List(b))
739
if ![a, b]
740
.into_iter()
741
.all(|x| x.is_supported_list_arithmetic_input()) =>
742
{
743
polars_bail!(
744
InvalidOperation:
745
"cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
746
"div", l, r,
747
)
748
},
749
(list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => {
750
let dtype = get_truediv_dtype(list_dtype.leaf_dtype(), other_dtype.leaf_dtype())?;
751
list_dtype.cast_leaf(dtype)
752
},
753
#[cfg(feature = "dtype-array")]
754
(list_dtype @ Array(..), other_dtype) | (other_dtype, list_dtype @ Array(..)) => {
755
let dtype = get_truediv_dtype(list_dtype.leaf_dtype(), other_dtype.leaf_dtype())?;
756
list_dtype.cast_leaf(dtype)
757
},
758
(Boolean, Float32) => Float32,
759
(Boolean, b) if b.is_numeric() => Float64,
760
(Boolean, Boolean) => Float64,
761
#[cfg(feature = "dtype-u8")]
762
(Float32, UInt8 | Int8) => Float32,
763
#[cfg(feature = "dtype-u16")]
764
(Float32, UInt16 | Int16) => Float32,
765
(Float32, other) if other.is_integer() => Float64,
766
(Float32, Float64) => Float64,
767
(Float32, _) => Float32,
768
(String, _) | (_, String) => polars_bail!(
769
InvalidOperation: "division with 'String' datatypes is not allowed"
770
),
771
#[cfg(feature = "dtype-decimal")]
772
(Decimal(_, Some(scale_left)), Decimal(_, _)) => {
773
let scale = _get_decimal_scale_div(*scale_left);
774
Decimal(None, Some(scale))
775
},
776
#[cfg(feature = "dtype-u8")]
777
(UInt8 | Int8, Float32) => Float32,
778
#[cfg(feature = "dtype-u16")]
779
(UInt16 | Int16, Float32) => Float32,
780
(dt, _) if dt.is_primitive_numeric() => Float64,
781
#[cfg(feature = "dtype-duration")]
782
(Duration(_), Duration(_)) => Float64,
783
#[cfg(feature = "dtype-duration")]
784
(Duration(_), dt) if dt.is_primitive_numeric() => left_dtype.clone(),
785
#[cfg(feature = "dtype-duration")]
786
(Duration(_), dt) => {
787
polars_bail!(InvalidOperation: "true division of {} with {} is not allowed", left_dtype, dt)
788
},
789
#[cfg(feature = "dtype-datetime")]
790
(Datetime(_, _), _) => {
791
polars_bail!(InvalidOperation: "division of 'Datetime' datatype is not allowed")
792
},
793
#[cfg(feature = "dtype-time")]
794
(Time, _) => polars_bail!(InvalidOperation: "division of 'Time' datatype is not allowed"),
795
#[cfg(feature = "dtype-date")]
796
(Date, _) => polars_bail!(InvalidOperation: "division of 'Date' datatype is not allowed"),
797
// we don't know what to do here, best return the dtype
798
(dt, _) => dt.clone(),
799
};
800
Ok(out_type)
801
}
802
803