Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/aexpr/function_expr/schema.rs
8384 views
1
use polars_core::utils::materialize_dyn_int;
2
3
use super::*;
4
5
impl IRFunctionExpr {
6
pub(crate) fn get_field(
7
&self,
8
_input_schema: &Schema,
9
fields: &[Field],
10
) -> PolarsResult<Field> {
11
use IRFunctionExpr::*;
12
13
let mapper = FieldsMapper { fields };
14
match self {
15
// Namespaces
16
#[cfg(feature = "dtype-array")]
17
ArrayExpr(func) => func.get_field(mapper),
18
BinaryExpr(s) => s.get_field(mapper),
19
#[cfg(feature = "dtype-categorical")]
20
Categorical(func) => func.get_field(mapper),
21
#[cfg(feature = "dtype-extension")]
22
Extension(func) => func.get_field(mapper),
23
ListExpr(func) => func.get_field(mapper),
24
#[cfg(feature = "strings")]
25
StringExpr(s) => s.get_field(mapper),
26
#[cfg(feature = "dtype-struct")]
27
StructExpr(s) => s.get_field(mapper),
28
#[cfg(feature = "temporal")]
29
TemporalExpr(fun) => fun.get_field(mapper),
30
#[cfg(feature = "bitwise")]
31
Bitwise(fun) => fun.get_field(mapper),
32
33
// Other expressions
34
Boolean(func) => func.get_field(mapper),
35
#[cfg(feature = "business")]
36
Business(func) => func.get_field(mapper),
37
#[cfg(feature = "abs")]
38
Abs => mapper.with_same_dtype(),
39
Negate => mapper.with_same_dtype(),
40
NullCount => mapper.with_dtype(IDX_DTYPE),
41
Pow(pow_function) => match pow_function {
42
IRPowFunction::Generic => mapper.pow_dtype(),
43
_ => mapper.map_numeric_to_float_dtype(true),
44
},
45
Coalesce => mapper.map_to_supertype(),
46
#[cfg(feature = "row_hash")]
47
Hash(..) => mapper.with_dtype(DataType::UInt64),
48
#[cfg(feature = "arg_where")]
49
ArgWhere => mapper.with_dtype(IDX_DTYPE),
50
#[cfg(feature = "index_of")]
51
IndexOf => mapper.with_dtype(IDX_DTYPE),
52
#[cfg(feature = "search_sorted")]
53
SearchSorted { .. } => mapper.with_dtype(IDX_DTYPE),
54
#[cfg(feature = "range")]
55
Range(func) => func.get_field(mapper),
56
#[cfg(feature = "trigonometry")]
57
Trigonometry(_) => mapper.map_to_float_dtype(),
58
#[cfg(feature = "trigonometry")]
59
Atan2 => mapper.map_to_float_dtype(),
60
#[cfg(feature = "sign")]
61
Sign => mapper
62
.ensure_satisfies(|_, dtype| dtype.is_numeric(), "sign")?
63
.with_same_dtype(),
64
FillNull => mapper.map_to_supertype(),
65
#[cfg(feature = "rolling_window")]
66
RollingExpr { function, options } => {
67
use IRRollingFunction::*;
68
match function {
69
Min | Max => mapper.with_same_dtype(),
70
Mean | Quantile | Std => mapper.moment_dtype(),
71
Var => mapper.var_dtype(),
72
Sum => mapper.sum_dtype(),
73
Rank => match options.fn_params {
74
Some(RollingFnParams::Rank {
75
method: RollingRankMethod::Average,
76
..
77
}) => mapper.with_dtype(DataType::Float64),
78
Some(RollingFnParams::Rank { .. }) => mapper.with_dtype(IDX_DTYPE),
79
_ => unreachable!("should be Some(RollingFnParams::Rank)"),
80
},
81
#[cfg(feature = "cov")]
82
CorrCov { .. } => mapper.map_to_float_dtype(),
83
#[cfg(feature = "moment")]
84
Skew | Kurtosis => mapper.map_to_float_dtype(),
85
Map(_) => mapper.try_map_field(|field| {
86
if options.weights.is_some() {
87
let dtype = match field.dtype() {
88
#[cfg(feature = "dtype-f16")]
89
DataType::Float16 => DataType::Float16,
90
DataType::Float32 => DataType::Float32,
91
_ => DataType::Float64,
92
};
93
Ok(Field::new(field.name().clone(), dtype))
94
} else {
95
Ok(field.clone())
96
}
97
}),
98
}
99
},
100
#[cfg(feature = "rolling_window_by")]
101
RollingExprBy {
102
function_by,
103
options,
104
..
105
} => {
106
use IRRollingFunctionBy::*;
107
match function_by {
108
MinBy | MaxBy => mapper.with_same_dtype(),
109
MeanBy | QuantileBy | StdBy => mapper.moment_dtype(),
110
VarBy => mapper.var_dtype(),
111
SumBy => mapper.sum_dtype(),
112
RankBy => match options.fn_params {
113
Some(RollingFnParams::Rank {
114
method: RollingRankMethod::Average,
115
..
116
}) => mapper.with_dtype(DataType::Float64),
117
Some(RollingFnParams::Rank { .. }) => mapper.with_dtype(IDX_DTYPE),
118
_ => unreachable!("should be Some(RollingFnParams::Rank)"),
119
},
120
}
121
},
122
Rechunk => mapper.with_same_dtype(),
123
Append { upcast } => {
124
if *upcast {
125
mapper.map_to_supertype()
126
} else {
127
mapper.with_same_dtype()
128
}
129
},
130
ShiftAndFill => mapper.with_same_dtype(),
131
DropNans => mapper.with_same_dtype(),
132
DropNulls => mapper.with_same_dtype(),
133
#[cfg(feature = "round_series")]
134
Clip {
135
has_min: _,
136
has_max: _,
137
} => mapper.with_same_dtype(),
138
#[cfg(feature = "mode")]
139
Mode { maintain_order: _ } => mapper.with_same_dtype(),
140
#[cfg(feature = "moment")]
141
Skew(_) => mapper.with_dtype(DataType::Float64),
142
#[cfg(feature = "moment")]
143
Kurtosis(..) => mapper.with_dtype(DataType::Float64),
144
ArgUnique | ArgMin | ArgMax | ArgSort { .. } => mapper.with_dtype(IDX_DTYPE),
145
MinBy | MaxBy => mapper.with_same_dtype(),
146
Product => mapper.map_dtype(|dtype| {
147
use DataType as T;
148
match dtype {
149
#[cfg(feature = "dtype-f16")]
150
T::Float16 => T::Float16,
151
T::Float32 => T::Float32,
152
T::Float64 => T::Float64,
153
T::UInt64 => T::UInt64,
154
#[cfg(feature = "dtype-i128")]
155
T::Int128 => T::Int128,
156
_ => T::Int64,
157
}
158
}),
159
Repeat => mapper.with_same_dtype(),
160
#[cfg(feature = "rank")]
161
Rank { options, .. } => mapper.with_dtype(match options.method {
162
RankMethod::Average => DataType::Float64,
163
_ => IDX_DTYPE,
164
}),
165
#[cfg(feature = "dtype-struct")]
166
AsStruct => {
167
let mut field_names = PlHashSet::with_capacity(fields.len() - 1);
168
let struct_fields = fields
169
.iter()
170
.map(|f| {
171
polars_ensure!(
172
field_names.insert(f.name.as_str()),
173
duplicate_field = f.name()
174
);
175
Ok(f.clone())
176
})
177
.collect::<PolarsResult<Vec<_>>>()?;
178
Ok(Field::new(
179
fields[0].name().clone(),
180
DataType::Struct(struct_fields),
181
))
182
},
183
#[cfg(feature = "top_k")]
184
TopK { .. } => mapper.with_same_dtype(),
185
#[cfg(feature = "top_k")]
186
TopKBy { .. } => mapper.with_same_dtype(),
187
#[cfg(feature = "dtype-struct")]
188
ValueCounts {
189
sort: _,
190
parallel: _,
191
name,
192
normalize,
193
} => mapper.map_dtype(|dt| {
194
let count_dt = if *normalize {
195
DataType::Float64
196
} else {
197
IDX_DTYPE
198
};
199
DataType::Struct(vec![
200
Field::new(fields[0].name().clone(), dt.clone()),
201
Field::new(name.clone(), count_dt),
202
])
203
}),
204
#[cfg(feature = "unique_counts")]
205
UniqueCounts => mapper.with_dtype(IDX_DTYPE),
206
Shift | Reverse => mapper.with_same_dtype(),
207
#[cfg(feature = "cum_agg")]
208
CumCount { .. } => mapper.with_dtype(IDX_DTYPE),
209
#[cfg(feature = "cum_agg")]
210
CumSum { .. } => mapper.map_dtype(cum::dtypes::cum_sum),
211
#[cfg(feature = "cum_agg")]
212
CumProd { .. } => mapper.map_dtype(cum::dtypes::cum_prod),
213
#[cfg(feature = "cum_agg")]
214
CumMin { .. } => mapper.with_same_dtype(),
215
#[cfg(feature = "cum_agg")]
216
CumMax { .. } => mapper.with_same_dtype(),
217
#[cfg(feature = "approx_unique")]
218
ApproxNUnique => mapper.with_dtype(IDX_DTYPE),
219
#[cfg(feature = "hist")]
220
Hist {
221
include_category,
222
include_breakpoint,
223
..
224
} => {
225
if *include_breakpoint || *include_category {
226
let mut fields = Vec::with_capacity(3);
227
if *include_breakpoint {
228
fields.push(Field::new(
229
PlSmallStr::from_static("breakpoint"),
230
DataType::Float64,
231
));
232
}
233
if *include_category {
234
fields.push(Field::new(
235
PlSmallStr::from_static("category"),
236
DataType::from_categories(Categories::global()),
237
));
238
}
239
fields.push(Field::new(PlSmallStr::from_static("count"), IDX_DTYPE));
240
mapper.with_dtype(DataType::Struct(fields))
241
} else {
242
mapper.with_dtype(IDX_DTYPE)
243
}
244
},
245
#[cfg(feature = "diff")]
246
Diff(_) => mapper.map_dtype(|dt| match dt {
247
#[cfg(feature = "dtype-datetime")]
248
DataType::Datetime(tu, _) => DataType::Duration(*tu),
249
#[cfg(feature = "dtype-date")]
250
DataType::Date => DataType::Duration(TimeUnit::Microseconds),
251
#[cfg(feature = "dtype-time")]
252
DataType::Time => DataType::Duration(TimeUnit::Nanoseconds),
253
DataType::UInt64 | DataType::UInt32 => DataType::Int64,
254
DataType::UInt16 => DataType::Int32,
255
DataType::UInt8 => DataType::Int16,
256
#[cfg(feature = "dtype-decimal")]
257
DataType::Decimal(_, scale) => {
258
DataType::Decimal(polars_compute::decimal::DEC128_MAX_PREC, *scale)
259
},
260
dt => dt.clone(),
261
}),
262
#[cfg(feature = "pct_change")]
263
PctChange => mapper.map_dtype(|dt| match dt {
264
#[cfg(feature = "dtype-f16")]
265
DataType::Float16 => dt.clone(),
266
DataType::Float32 => dt.clone(),
267
_ => DataType::Float64,
268
}),
269
#[cfg(feature = "interpolate")]
270
Interpolate(method) => match method {
271
InterpolationMethod::Linear => mapper.map_numeric_to_float_dtype(false),
272
InterpolationMethod::Nearest => mapper.with_same_dtype(),
273
},
274
#[cfg(feature = "interpolate_by")]
275
InterpolateBy => mapper.map_numeric_to_float_dtype(true),
276
#[cfg(feature = "log")]
277
Entropy { .. } | Log1p | Exp => mapper.map_to_float_dtype(),
278
#[cfg(feature = "log")]
279
Log => mapper.log_dtype(),
280
Unique(_) => mapper.with_same_dtype(),
281
#[cfg(feature = "round_series")]
282
Round { .. } | RoundSF { .. } | Floor | Ceil => mapper.with_same_dtype(),
283
#[cfg(feature = "fused")]
284
Fused(_) => mapper.map_to_supertype(),
285
ConcatExpr(_) => mapper.map_to_supertype(),
286
#[cfg(feature = "cov")]
287
Correlation { .. } => mapper.map_to_float_dtype(),
288
#[cfg(feature = "peaks")]
289
PeakMin | PeakMax => mapper.with_dtype(DataType::Boolean),
290
#[cfg(feature = "cutqcut")]
291
Cut {
292
include_breaks: false,
293
..
294
} => mapper.with_dtype(DataType::from_categories(Categories::global())),
295
#[cfg(feature = "cutqcut")]
296
Cut {
297
include_breaks: true,
298
..
299
} => {
300
let struct_dt = DataType::Struct(vec![
301
Field::new(PlSmallStr::from_static("breakpoint"), DataType::Float64),
302
Field::new(
303
PlSmallStr::from_static("category"),
304
DataType::from_categories(Categories::global()),
305
),
306
]);
307
mapper.with_dtype(struct_dt)
308
},
309
#[cfg(feature = "repeat_by")]
310
RepeatBy => mapper.map_dtype(|dt| DataType::List(dt.clone().into())),
311
#[cfg(feature = "dtype-array")]
312
Reshape(dims) => mapper.try_map_dtype(|dt: &DataType| {
313
let mut wrapped_dtype = dt.leaf_dtype().clone();
314
for dim in dims[1..].iter().rev() {
315
let Some(array_size) = dim.get() else {
316
polars_bail!(InvalidOperation: "can only infer the first dimension");
317
};
318
wrapped_dtype = DataType::Array(Box::new(wrapped_dtype), array_size as usize);
319
}
320
Ok(wrapped_dtype)
321
}),
322
#[cfg(feature = "cutqcut")]
323
QCut {
324
include_breaks: false,
325
..
326
} => mapper.with_dtype(DataType::from_categories(Categories::global())),
327
#[cfg(feature = "cutqcut")]
328
QCut {
329
include_breaks: true,
330
..
331
} => {
332
let struct_dt = DataType::Struct(vec![
333
Field::new(PlSmallStr::from_static("breakpoint"), DataType::Float64),
334
Field::new(
335
PlSmallStr::from_static("category"),
336
DataType::from_categories(Categories::global()),
337
),
338
]);
339
mapper.with_dtype(struct_dt)
340
},
341
#[cfg(feature = "rle")]
342
RLE => mapper.map_dtype(|dt| {
343
DataType::Struct(vec![
344
Field::new(PlSmallStr::from_static("len"), IDX_DTYPE),
345
Field::new(PlSmallStr::from_static("value"), dt.clone()),
346
])
347
}),
348
#[cfg(feature = "rle")]
349
RLEID => mapper.with_dtype(IDX_DTYPE),
350
ToPhysical => mapper.to_physical_type(),
351
#[cfg(feature = "random")]
352
Random { .. } => mapper.with_same_dtype(),
353
SetSortedFlag(_) => mapper.with_same_dtype(),
354
#[cfg(feature = "ffi_plugin")]
355
FfiPlugin {
356
flags: _,
357
lib,
358
symbol,
359
kwargs,
360
} => unsafe { plugin::plugin_field(fields, lib, symbol.as_ref(), kwargs) },
361
362
FoldHorizontal { return_dtype, .. } => match return_dtype {
363
None => mapper.with_same_dtype(),
364
Some(dtype) => mapper.with_dtype(dtype.clone()),
365
},
366
ReduceHorizontal { return_dtype, .. } => match return_dtype {
367
None => mapper.map_to_supertype(),
368
Some(dtype) => mapper.with_dtype(dtype.clone()),
369
},
370
#[cfg(feature = "dtype-struct")]
371
CumReduceHorizontal { return_dtype, .. } => match return_dtype {
372
None => mapper.with_dtype(DataType::Struct(fields.to_vec())),
373
Some(dtype) => mapper.with_dtype(DataType::Struct(
374
fields
375
.iter()
376
.map(|f| Field::new(f.name().clone(), dtype.clone()))
377
.collect(),
378
)),
379
},
380
#[cfg(feature = "dtype-struct")]
381
CumFoldHorizontal {
382
return_dtype,
383
include_init,
384
..
385
} => match return_dtype {
386
None => mapper.with_dtype(DataType::Struct(
387
fields
388
.iter()
389
.skip(usize::from(!include_init))
390
.map(|f| Field::new(f.name().clone(), fields[0].dtype().clone()))
391
.collect(),
392
)),
393
Some(dtype) => mapper.with_dtype(DataType::Struct(
394
fields
395
.iter()
396
.skip(usize::from(!include_init))
397
.map(|f| Field::new(f.name().clone(), dtype.clone()))
398
.collect(),
399
)),
400
},
401
402
MaxHorizontal => mapper.map_to_supertype(),
403
MinHorizontal => mapper.map_to_supertype(),
404
SumHorizontal { .. } => mapper.map_to_supertype().map(|mut f| {
405
if f.dtype == DataType::Boolean {
406
f.dtype = IDX_DTYPE;
407
}
408
f
409
}),
410
MeanHorizontal { .. } => mapper.map_to_supertype().map(|mut f| {
411
match f.dtype {
412
#[cfg(feature = "dtype-f16")]
413
DataType::Float16 => {},
414
DataType::Float32 => {},
415
_ => {
416
f.dtype = DataType::Float64;
417
},
418
}
419
f
420
}),
421
#[cfg(feature = "ewma")]
422
EwmMean { .. } => mapper.map_numeric_to_float_dtype(true),
423
#[cfg(feature = "ewma_by")]
424
EwmMeanBy { .. } => mapper.map_numeric_to_float_dtype(true),
425
#[cfg(feature = "ewma")]
426
EwmStd { .. } => mapper.map_numeric_to_float_dtype(true),
427
#[cfg(feature = "ewma")]
428
EwmVar { .. } => mapper.var_dtype(),
429
#[cfg(feature = "replace")]
430
Replace => mapper.with_same_dtype(),
431
#[cfg(feature = "replace")]
432
ReplaceStrict { return_dtype } => mapper.replace_dtype(return_dtype.clone()),
433
FillNullWithStrategy(_) => mapper.with_same_dtype(),
434
GatherEvery { .. } => mapper.with_same_dtype(),
435
#[cfg(feature = "reinterpret")]
436
Reinterpret(signed) => {
437
let dt = if *signed {
438
DataType::Int64
439
} else {
440
DataType::UInt64
441
};
442
mapper.with_dtype(dt)
443
},
444
ExtendConstant => mapper.with_same_dtype(),
445
446
RowEncode(..) => mapper.try_map_field(|_| {
447
Ok(Field::new(
448
PlSmallStr::from_static("row_encoded"),
449
DataType::BinaryOffset,
450
))
451
}),
452
#[cfg(feature = "dtype-struct")]
453
RowDecode(fields, _) => mapper.with_dtype(DataType::Struct(fields.to_vec())),
454
}
455
}
456
457
pub(crate) fn output_name(&self) -> Option<OutputName> {
458
match self {
459
#[cfg(feature = "dtype-struct")]
460
IRFunctionExpr::StructExpr(IRStructFunction::FieldByName(name)) => {
461
Some(OutputName::Field(name.clone()))
462
},
463
_ => None,
464
}
465
}
466
}
467
468
pub struct FieldsMapper<'a> {
469
fields: &'a [Field],
470
}
471
472
impl<'a> FieldsMapper<'a> {
473
pub fn new(fields: &'a [Field]) -> Self {
474
Self { fields }
475
}
476
477
pub fn args(&self) -> &[Field] {
478
self.fields
479
}
480
481
/// Field with the same dtype.
482
pub fn with_same_dtype(&self) -> PolarsResult<Field> {
483
self.map_dtype(|dtype| dtype.clone())
484
}
485
486
/// Set a dtype.
487
pub fn with_dtype(&self, dtype: DataType) -> PolarsResult<Field> {
488
Ok(Field::new(self.fields[0].name().clone(), dtype))
489
}
490
491
/// Map a single dtype.
492
pub fn map_dtype(&self, func: impl FnOnce(&DataType) -> DataType) -> PolarsResult<Field> {
493
let dtype = func(self.fields[0].dtype());
494
Ok(Field::new(self.fields[0].name().clone(), dtype))
495
}
496
497
pub fn get_fields_lens(&self) -> usize {
498
self.fields.len()
499
}
500
501
/// Map a single field with a potentially failing mapper function.
502
pub fn try_map_field(
503
&self,
504
func: impl FnOnce(&Field) -> PolarsResult<Field>,
505
) -> PolarsResult<Field> {
506
func(&self.fields[0])
507
}
508
509
pub fn var_dtype(&self) -> PolarsResult<Field> {
510
if self.fields[0].dtype().leaf_dtype().is_duration() {
511
let map_inner = |dt: &DataType| match dt {
512
dt if dt.is_temporal() => {
513
polars_bail!(InvalidOperation: "operation `var` is not supported for `{dt}`")
514
},
515
dt => Ok(dt.clone()),
516
};
517
518
self.try_map_dtype(|dt| match dt {
519
#[cfg(feature = "dtype-array")]
520
DataType::Array(inner, _) => map_inner(inner),
521
DataType::List(inner) => map_inner(inner),
522
_ => map_inner(dt),
523
})
524
} else {
525
self.moment_dtype()
526
}
527
}
528
529
pub fn moment_dtype(&self) -> PolarsResult<Field> {
530
let map_inner = |dt: &DataType| match dt {
531
DataType::Boolean => DataType::Float64,
532
#[cfg(feature = "dtype-f16")]
533
DataType::Float16 => DataType::Float16,
534
DataType::Float32 => DataType::Float32,
535
DataType::Float64 => DataType::Float64,
536
dt if dt.is_primitive_numeric() => DataType::Float64,
537
#[cfg(feature = "dtype-date")]
538
DataType::Date => DataType::Datetime(TimeUnit::Microseconds, None),
539
#[cfg(feature = "dtype-datetime")]
540
dt @ DataType::Datetime(_, _) => dt.clone(),
541
#[cfg(feature = "dtype-duration")]
542
dt @ DataType::Duration(_) => dt.clone(),
543
#[cfg(feature = "dtype-time")]
544
dt @ DataType::Time => dt.clone(),
545
#[cfg(feature = "dtype-decimal")]
546
DataType::Decimal(..) => DataType::Float64,
547
548
// All other types get mapped to a single `null` of the same type.
549
dt => dt.clone(),
550
};
551
552
self.map_dtype(|dt| match dt {
553
#[cfg(feature = "dtype-array")]
554
DataType::Array(inner, _) => map_inner(inner),
555
DataType::List(inner) => map_inner(inner),
556
_ => map_inner(dt),
557
})
558
}
559
560
/// Map to a float supertype.
561
pub fn map_to_float_dtype(&self) -> PolarsResult<Field> {
562
self.map_dtype(|dtype| match dtype {
563
#[cfg(feature = "dtype-f16")]
564
DataType::Float16 => DataType::Float16,
565
DataType::Float32 => DataType::Float32,
566
_ => DataType::Float64,
567
})
568
}
569
570
/// Map to a float supertype if numeric, else preserve
571
pub fn map_numeric_to_float_dtype(&self, coerce_decimal: bool) -> PolarsResult<Field> {
572
self.map_dtype(|dt| {
573
let should_coerce = match dt {
574
#[cfg(feature = "dtype-f16")]
575
DataType::Float16 => false,
576
DataType::Float32 => false,
577
#[cfg(feature = "dtype-decimal")]
578
DataType::Decimal(..) => coerce_decimal,
579
DataType::Boolean => true,
580
dt => dt.is_primitive_numeric(),
581
};
582
583
if should_coerce {
584
DataType::Float64
585
} else {
586
dt.clone()
587
}
588
})
589
}
590
591
/// Map to a physical type.
592
pub fn to_physical_type(&self) -> PolarsResult<Field> {
593
self.map_dtype(|dtype| dtype.to_physical())
594
}
595
596
/// Map a single dtype with a potentially failing mapper function.
597
pub fn try_map_dtype(
598
&self,
599
func: impl FnOnce(&DataType) -> PolarsResult<DataType>,
600
) -> PolarsResult<Field> {
601
let dtype = func(self.fields[0].dtype())?;
602
Ok(Field::new(self.fields[0].name().clone(), dtype))
603
}
604
605
/// Map all dtypes with a potentially failing mapper function.
606
pub fn try_map_dtypes(
607
&self,
608
func: impl FnOnce(&[&DataType]) -> PolarsResult<DataType>,
609
) -> PolarsResult<Field> {
610
let mut fld = self.fields[0].clone();
611
let dtypes = self
612
.fields
613
.iter()
614
.map(|fld| fld.dtype())
615
.collect::<Vec<_>>();
616
let new_type = func(&dtypes)?;
617
fld.coerce(new_type);
618
Ok(fld)
619
}
620
621
/// Map the dtype to the "supertype" of all fields.
622
pub fn map_to_supertype(&self) -> PolarsResult<Field> {
623
let st = args_to_supertype(self.fields)?;
624
let mut first = self.fields[0].clone();
625
first.coerce(st);
626
Ok(first)
627
}
628
629
/// Map the dtype to the dtype of the list/array elements.
630
pub fn map_to_list_and_array_inner_dtype(&self) -> PolarsResult<Field> {
631
let mut first = self.fields[0].clone();
632
let dt = first
633
.dtype()
634
.inner_dtype()
635
.cloned()
636
.unwrap_or_else(|| DataType::Unknown(Default::default()));
637
first.coerce(dt);
638
Ok(first)
639
}
640
641
#[cfg(feature = "dtype-array")]
642
/// Map the dtype to the dtype of the array elements, with typo validation.
643
pub fn try_map_to_array_inner_dtype(&self) -> PolarsResult<Field> {
644
let dt = self.fields[0].dtype();
645
match dt {
646
DataType::Array(_, _) => self.map_to_list_and_array_inner_dtype(),
647
_ => polars_bail!(InvalidOperation: "expected Array type, got: {}", dt),
648
}
649
}
650
651
/// Map the dtypes to the "supertype" of a list of lists.
652
pub fn map_to_list_supertype(&self) -> PolarsResult<Field> {
653
self.try_map_dtypes(|dts| {
654
let mut super_type_inner = None;
655
656
for dt in dts {
657
match dt {
658
DataType::List(inner) => match super_type_inner {
659
None => super_type_inner = Some(*inner.clone()),
660
Some(st_inner) => {
661
super_type_inner = Some(try_get_supertype(&st_inner, inner)?)
662
},
663
},
664
dt => match super_type_inner {
665
None => super_type_inner = Some((*dt).clone()),
666
Some(st_inner) => {
667
super_type_inner = Some(try_get_supertype(&st_inner, dt)?)
668
},
669
},
670
}
671
}
672
Ok(DataType::List(Box::new(super_type_inner.unwrap())))
673
})
674
}
675
676
/// Set the timezone of a datetime dtype.
677
#[cfg(feature = "timezones")]
678
pub fn map_datetime_dtype_timezone(&self, tz: Option<&TimeZone>) -> PolarsResult<Field> {
679
self.try_map_dtype(|dt| {
680
if let DataType::Datetime(tu, _) = dt {
681
Ok(DataType::Datetime(*tu, tz.cloned()))
682
} else {
683
polars_bail!(op = "replace-time-zone", got = dt, expected = "Datetime");
684
}
685
})
686
}
687
688
pub fn sum_dtype(&self) -> PolarsResult<Field> {
689
use DataType::*;
690
self.map_dtype(|dtype| match dtype {
691
Int8 | UInt8 | Int16 | UInt16 => Int64,
692
Boolean => IDX_DTYPE,
693
dt => dt.clone(),
694
})
695
}
696
697
pub fn nested_sum_type(&self) -> PolarsResult<Field> {
698
let mut first = self.fields[0].clone();
699
use DataType::*;
700
let dt = first.dtype().inner_dtype().cloned().ok_or_else(|| {
701
polars_err!(
702
InvalidOperation:"expected List or Array type, got dtype: {}",
703
first.dtype()
704
)
705
})?;
706
707
match dt {
708
Boolean => first.coerce(IDX_DTYPE),
709
UInt8 | Int8 | Int16 | UInt16 => first.coerce(Int64),
710
_ => first.coerce(dt),
711
}
712
Ok(first)
713
}
714
715
pub fn nested_mean_median_type(&self) -> PolarsResult<Field> {
716
let mut first = self.fields[0].clone();
717
use DataType::*;
718
let dt = first.dtype().inner_dtype().cloned().ok_or_else(|| {
719
polars_err!(
720
InvalidOperation:"expected List or Array type, got dtype: {}",
721
first.dtype()
722
)
723
})?;
724
725
let new_dt = match dt {
726
#[cfg(feature = "dtype-datetime")]
727
Date => Datetime(TimeUnit::Microseconds, None),
728
dt if dt.is_temporal() => dt,
729
#[cfg(feature = "dtype-f16")]
730
Float16 => Float16,
731
Float32 => Float32,
732
_ => Float64,
733
};
734
first.coerce(new_dt);
735
Ok(first)
736
}
737
738
pub(super) fn pow_dtype(&self) -> PolarsResult<Field> {
739
let dtype1 = self.fields[0].dtype();
740
let dtype2 = self.fields[1].dtype();
741
let out_dtype = if dtype1.is_integer() {
742
if dtype2.is_float() { dtype2 } else { dtype1 }
743
} else {
744
dtype1
745
};
746
Ok(Field::new(self.fields[0].name().clone(), out_dtype.clone()))
747
}
748
749
pub(super) fn log_dtype(&self) -> PolarsResult<Field> {
750
let dtype1 = self.fields[0].dtype();
751
let dtype2 = self.fields[1].dtype();
752
let out_dtype = if dtype1.is_float() {
753
dtype1
754
} else if dtype2.is_float() {
755
dtype2
756
} else {
757
&DataType::Float64
758
};
759
Ok(Field::new(self.fields[0].name().clone(), out_dtype.clone()))
760
}
761
762
#[cfg(feature = "extract_jsonpath")]
763
pub fn with_opt_dtype(&self, dtype: Option<DataType>) -> PolarsResult<Field> {
764
let dtype = dtype.unwrap_or_else(|| DataType::Unknown(Default::default()));
765
self.with_dtype(dtype)
766
}
767
768
#[cfg(feature = "replace")]
769
pub fn replace_dtype(&self, return_dtype: Option<DataType>) -> PolarsResult<Field> {
770
let dtype = match return_dtype {
771
Some(dtype) => dtype,
772
None => {
773
let new = &self.fields[2];
774
let default = self.fields.get(3);
775
776
// @HACK: Related to implicit implode see #22149.
777
let inner_dtype = new.dtype().inner_dtype().unwrap_or(new.dtype());
778
779
match default {
780
Some(default) => try_get_supertype(default.dtype(), inner_dtype)?,
781
None => inner_dtype.clone(),
782
}
783
},
784
};
785
self.with_dtype(dtype)
786
}
787
788
fn ensure_satisfies(
789
self,
790
mut f: impl FnMut(usize, &DataType) -> bool,
791
op: &'static str,
792
) -> PolarsResult<Self> {
793
for (i, field) in self.fields.iter().enumerate() {
794
polars_ensure!(
795
f(i, field.dtype()),
796
opidx = op,
797
idx = i,
798
self.fields[i].dtype()
799
);
800
}
801
802
Ok(self)
803
}
804
}
805
806
pub(crate) fn args_to_supertype<D: AsRef<DataType>>(dtypes: &[D]) -> PolarsResult<DataType> {
807
let mut st = dtypes[0].as_ref().clone();
808
for dt in &dtypes[1..] {
809
st = try_get_supertype(&st, dt.as_ref())?
810
}
811
812
match (dtypes[0].as_ref(), &st) {
813
#[cfg(feature = "dtype-categorical")]
814
(cat @ DataType::Categorical(_, _), DataType::String) => st = cat.clone(),
815
_ => {
816
if let DataType::Unknown(kind) = st {
817
match kind {
818
UnknownKind::Float => st = DataType::Float64,
819
UnknownKind::Int(v) => {
820
st = materialize_dyn_int(v).dtype();
821
},
822
UnknownKind::Str => st = DataType::String,
823
_ => {},
824
}
825
}
826
},
827
}
828
829
Ok(st)
830
}
831
832