// Source: GitHub repository pola-rs/polars
// Path: blob/main/crates/polars-plan/src/plans/aexpr/function_expr/schema.rs
1
use polars_core::utils::materialize_dyn_int;
2
3
use super::*;
4
5
impl IRFunctionExpr {
6
pub(crate) fn get_field(
7
&self,
8
_input_schema: &Schema,
9
fields: &[Field],
10
) -> PolarsResult<Field> {
11
use IRFunctionExpr::*;
12
13
let mapper = FieldsMapper { fields };
14
match self {
15
// Namespaces
16
#[cfg(feature = "dtype-array")]
17
ArrayExpr(func) => func.get_field(mapper),
18
BinaryExpr(s) => s.get_field(mapper),
19
#[cfg(feature = "dtype-categorical")]
20
Categorical(func) => func.get_field(mapper),
21
#[cfg(feature = "dtype-extension")]
22
Extension(func) => func.get_field(mapper),
23
ListExpr(func) => func.get_field(mapper),
24
#[cfg(feature = "strings")]
25
StringExpr(s) => s.get_field(mapper),
26
#[cfg(feature = "dtype-struct")]
27
StructExpr(s) => s.get_field(mapper),
28
#[cfg(feature = "temporal")]
29
TemporalExpr(fun) => fun.get_field(mapper),
30
#[cfg(feature = "bitwise")]
31
Bitwise(fun) => fun.get_field(mapper),
32
33
// Other expressions
34
Boolean(func) => func.get_field(mapper),
35
#[cfg(feature = "business")]
36
Business(func) => func.get_field(mapper),
37
#[cfg(feature = "abs")]
38
Abs => mapper.with_same_dtype(),
39
Negate => mapper.with_same_dtype(),
40
NullCount => mapper.with_dtype(IDX_DTYPE),
41
Pow(pow_function) => match pow_function {
42
IRPowFunction::Generic => mapper.pow_dtype(),
43
_ => mapper.map_numeric_to_float_dtype(true),
44
},
45
Coalesce => mapper.map_to_supertype(),
46
#[cfg(feature = "row_hash")]
47
Hash(..) => mapper.with_dtype(DataType::UInt64),
48
#[cfg(feature = "arg_where")]
49
ArgWhere => mapper.with_dtype(IDX_DTYPE),
50
#[cfg(feature = "index_of")]
51
IndexOf => mapper.with_dtype(IDX_DTYPE),
52
#[cfg(feature = "search_sorted")]
53
SearchSorted { .. } => mapper.with_dtype(IDX_DTYPE),
54
#[cfg(feature = "range")]
55
Range(func) => func.get_field(mapper),
56
#[cfg(feature = "trigonometry")]
57
Trigonometry(_) => mapper.map_to_float_dtype(),
58
#[cfg(feature = "trigonometry")]
59
Atan2 => mapper.map_to_float_dtype(),
60
#[cfg(feature = "sign")]
61
Sign => mapper
62
.ensure_satisfies(|_, dtype| dtype.is_numeric(), "sign")?
63
.with_same_dtype(),
64
FillNull => mapper.map_to_supertype(),
65
#[cfg(feature = "rolling_window")]
66
RollingExpr { function, options } => {
67
use IRRollingFunction::*;
68
match function {
69
Min | Max => mapper.with_same_dtype(),
70
Mean | Quantile | Std => mapper.moment_dtype(),
71
Var => mapper.var_dtype(),
72
Sum => mapper.sum_dtype(),
73
Rank => match options.fn_params {
74
Some(RollingFnParams::Rank {
75
method: RollingRankMethod::Average,
76
..
77
}) => mapper.with_dtype(DataType::Float64),
78
Some(RollingFnParams::Rank { .. }) => mapper.with_dtype(IDX_DTYPE),
79
_ => unreachable!("should be Some(RollingFnParams::Rank)"),
80
},
81
#[cfg(feature = "cov")]
82
CorrCov { .. } => mapper.map_to_float_dtype(),
83
#[cfg(feature = "moment")]
84
Skew | Kurtosis => mapper.map_to_float_dtype(),
85
Map(_) => mapper.try_map_field(|field| {
86
if options.weights.is_some() {
87
let dtype = match field.dtype() {
88
#[cfg(feature = "dtype-f16")]
89
DataType::Float16 => DataType::Float16,
90
DataType::Float32 => DataType::Float32,
91
_ => DataType::Float64,
92
};
93
Ok(Field::new(field.name().clone(), dtype))
94
} else {
95
Ok(field.clone())
96
}
97
}),
98
}
99
},
100
#[cfg(feature = "rolling_window_by")]
101
RollingExprBy {
102
function_by,
103
options,
104
..
105
} => {
106
use IRRollingFunctionBy::*;
107
match function_by {
108
MinBy | MaxBy => mapper.with_same_dtype(),
109
MeanBy | QuantileBy | StdBy => mapper.moment_dtype(),
110
VarBy => mapper.var_dtype(),
111
SumBy => mapper.sum_dtype(),
112
RankBy => match options.fn_params {
113
Some(RollingFnParams::Rank {
114
method: RollingRankMethod::Average,
115
..
116
}) => mapper.with_dtype(DataType::Float64),
117
Some(RollingFnParams::Rank { .. }) => mapper.with_dtype(IDX_DTYPE),
118
_ => unreachable!("should be Some(RollingFnParams::Rank)"),
119
},
120
}
121
},
122
Rechunk => mapper.with_same_dtype(),
123
Append { upcast } => {
124
if *upcast {
125
mapper.map_to_supertype()
126
} else {
127
mapper.with_same_dtype()
128
}
129
},
130
ShiftAndFill => mapper.with_same_dtype(),
131
DropNans => mapper.with_same_dtype(),
132
DropNulls => mapper.with_same_dtype(),
133
#[cfg(feature = "round_series")]
134
Clip {
135
has_min: _,
136
has_max: _,
137
} => mapper.with_same_dtype(),
138
#[cfg(feature = "mode")]
139
Mode { maintain_order: _ } => mapper.with_same_dtype(),
140
#[cfg(feature = "moment")]
141
Skew(_) => mapper.with_dtype(DataType::Float64),
142
#[cfg(feature = "moment")]
143
Kurtosis(..) => mapper.with_dtype(DataType::Float64),
144
ArgUnique | ArgMin | ArgMax | ArgSort { .. } => mapper.with_dtype(IDX_DTYPE),
145
Product => mapper.map_dtype(|dtype| {
146
use DataType as T;
147
match dtype {
148
#[cfg(feature = "dtype-f16")]
149
T::Float16 => T::Float16,
150
T::Float32 => T::Float32,
151
T::Float64 => T::Float64,
152
T::UInt64 => T::UInt64,
153
#[cfg(feature = "dtype-i128")]
154
T::Int128 => T::Int128,
155
_ => T::Int64,
156
}
157
}),
158
Repeat => mapper.with_same_dtype(),
159
#[cfg(feature = "rank")]
160
Rank { options, .. } => mapper.with_dtype(match options.method {
161
RankMethod::Average => DataType::Float64,
162
_ => IDX_DTYPE,
163
}),
164
#[cfg(feature = "dtype-struct")]
165
AsStruct => {
166
let mut field_names = PlHashSet::with_capacity(fields.len() - 1);
167
let struct_fields = fields
168
.iter()
169
.map(|f| {
170
polars_ensure!(
171
field_names.insert(f.name.as_str()),
172
duplicate_field = f.name()
173
);
174
Ok(f.clone())
175
})
176
.collect::<PolarsResult<Vec<_>>>()?;
177
Ok(Field::new(
178
fields[0].name().clone(),
179
DataType::Struct(struct_fields),
180
))
181
},
182
#[cfg(feature = "top_k")]
183
TopK { .. } => mapper.with_same_dtype(),
184
#[cfg(feature = "top_k")]
185
TopKBy { .. } => mapper.with_same_dtype(),
186
#[cfg(feature = "dtype-struct")]
187
ValueCounts {
188
sort: _,
189
parallel: _,
190
name,
191
normalize,
192
} => mapper.map_dtype(|dt| {
193
let count_dt = if *normalize {
194
DataType::Float64
195
} else {
196
IDX_DTYPE
197
};
198
DataType::Struct(vec![
199
Field::new(fields[0].name().clone(), dt.clone()),
200
Field::new(name.clone(), count_dt),
201
])
202
}),
203
#[cfg(feature = "unique_counts")]
204
UniqueCounts => mapper.with_dtype(IDX_DTYPE),
205
Shift | Reverse => mapper.with_same_dtype(),
206
#[cfg(feature = "cum_agg")]
207
CumCount { .. } => mapper.with_dtype(IDX_DTYPE),
208
#[cfg(feature = "cum_agg")]
209
CumSum { .. } => mapper.map_dtype(cum::dtypes::cum_sum),
210
#[cfg(feature = "cum_agg")]
211
CumProd { .. } => mapper.map_dtype(cum::dtypes::cum_prod),
212
#[cfg(feature = "cum_agg")]
213
CumMin { .. } => mapper.with_same_dtype(),
214
#[cfg(feature = "cum_agg")]
215
CumMax { .. } => mapper.with_same_dtype(),
216
#[cfg(feature = "approx_unique")]
217
ApproxNUnique => mapper.with_dtype(IDX_DTYPE),
218
#[cfg(feature = "hist")]
219
Hist {
220
include_category,
221
include_breakpoint,
222
..
223
} => {
224
if *include_breakpoint || *include_category {
225
let mut fields = Vec::with_capacity(3);
226
if *include_breakpoint {
227
fields.push(Field::new(
228
PlSmallStr::from_static("breakpoint"),
229
DataType::Float64,
230
));
231
}
232
if *include_category {
233
fields.push(Field::new(
234
PlSmallStr::from_static("category"),
235
DataType::from_categories(Categories::global()),
236
));
237
}
238
fields.push(Field::new(PlSmallStr::from_static("count"), IDX_DTYPE));
239
mapper.with_dtype(DataType::Struct(fields))
240
} else {
241
mapper.with_dtype(IDX_DTYPE)
242
}
243
},
244
#[cfg(feature = "diff")]
245
Diff(_) => mapper.map_dtype(|dt| match dt {
246
#[cfg(feature = "dtype-datetime")]
247
DataType::Datetime(tu, _) => DataType::Duration(*tu),
248
#[cfg(feature = "dtype-date")]
249
DataType::Date => DataType::Duration(TimeUnit::Microseconds),
250
#[cfg(feature = "dtype-time")]
251
DataType::Time => DataType::Duration(TimeUnit::Nanoseconds),
252
DataType::UInt64 | DataType::UInt32 => DataType::Int64,
253
DataType::UInt16 => DataType::Int32,
254
DataType::UInt8 => DataType::Int16,
255
dt => dt.clone(),
256
}),
257
#[cfg(feature = "pct_change")]
258
PctChange => mapper.map_dtype(|dt| match dt {
259
#[cfg(feature = "dtype-f16")]
260
DataType::Float16 => dt.clone(),
261
DataType::Float32 => dt.clone(),
262
_ => DataType::Float64,
263
}),
264
#[cfg(feature = "interpolate")]
265
Interpolate(method) => match method {
266
InterpolationMethod::Linear => mapper.map_numeric_to_float_dtype(false),
267
InterpolationMethod::Nearest => mapper.with_same_dtype(),
268
},
269
#[cfg(feature = "interpolate_by")]
270
InterpolateBy => mapper.map_numeric_to_float_dtype(true),
271
#[cfg(feature = "log")]
272
Entropy { .. } | Log1p | Exp => mapper.map_to_float_dtype(),
273
#[cfg(feature = "log")]
274
Log => mapper.log_dtype(),
275
Unique(_) => mapper.with_same_dtype(),
276
#[cfg(feature = "round_series")]
277
Round { .. } | RoundSF { .. } | Floor | Ceil => mapper.with_same_dtype(),
278
#[cfg(feature = "fused")]
279
Fused(_) => mapper.map_to_supertype(),
280
ConcatExpr(_) => mapper.map_to_supertype(),
281
#[cfg(feature = "cov")]
282
Correlation { .. } => mapper.map_to_float_dtype(),
283
#[cfg(feature = "peaks")]
284
PeakMin | PeakMax => mapper.with_dtype(DataType::Boolean),
285
#[cfg(feature = "cutqcut")]
286
Cut {
287
include_breaks: false,
288
..
289
} => mapper.with_dtype(DataType::from_categories(Categories::global())),
290
#[cfg(feature = "cutqcut")]
291
Cut {
292
include_breaks: true,
293
..
294
} => {
295
let struct_dt = DataType::Struct(vec![
296
Field::new(PlSmallStr::from_static("breakpoint"), DataType::Float64),
297
Field::new(
298
PlSmallStr::from_static("category"),
299
DataType::from_categories(Categories::global()),
300
),
301
]);
302
mapper.with_dtype(struct_dt)
303
},
304
#[cfg(feature = "repeat_by")]
305
RepeatBy => mapper.map_dtype(|dt| DataType::List(dt.clone().into())),
306
#[cfg(feature = "dtype-array")]
307
Reshape(dims) => mapper.try_map_dtype(|dt: &DataType| {
308
let mut wrapped_dtype = dt.leaf_dtype().clone();
309
for dim in dims[1..].iter().rev() {
310
let Some(array_size) = dim.get() else {
311
polars_bail!(InvalidOperation: "can only infer the first dimension");
312
};
313
wrapped_dtype = DataType::Array(Box::new(wrapped_dtype), array_size as usize);
314
}
315
Ok(wrapped_dtype)
316
}),
317
#[cfg(feature = "cutqcut")]
318
QCut {
319
include_breaks: false,
320
..
321
} => mapper.with_dtype(DataType::from_categories(Categories::global())),
322
#[cfg(feature = "cutqcut")]
323
QCut {
324
include_breaks: true,
325
..
326
} => {
327
let struct_dt = DataType::Struct(vec![
328
Field::new(PlSmallStr::from_static("breakpoint"), DataType::Float64),
329
Field::new(
330
PlSmallStr::from_static("category"),
331
DataType::from_categories(Categories::global()),
332
),
333
]);
334
mapper.with_dtype(struct_dt)
335
},
336
#[cfg(feature = "rle")]
337
RLE => mapper.map_dtype(|dt| {
338
DataType::Struct(vec![
339
Field::new(PlSmallStr::from_static("len"), IDX_DTYPE),
340
Field::new(PlSmallStr::from_static("value"), dt.clone()),
341
])
342
}),
343
#[cfg(feature = "rle")]
344
RLEID => mapper.with_dtype(IDX_DTYPE),
345
ToPhysical => mapper.to_physical_type(),
346
#[cfg(feature = "random")]
347
Random { .. } => mapper.with_same_dtype(),
348
SetSortedFlag(_) => mapper.with_same_dtype(),
349
#[cfg(feature = "ffi_plugin")]
350
FfiPlugin {
351
flags: _,
352
lib,
353
symbol,
354
kwargs,
355
} => unsafe { plugin::plugin_field(fields, lib, symbol.as_ref(), kwargs) },
356
357
FoldHorizontal { return_dtype, .. } => match return_dtype {
358
None => mapper.with_same_dtype(),
359
Some(dtype) => mapper.with_dtype(dtype.clone()),
360
},
361
ReduceHorizontal { return_dtype, .. } => match return_dtype {
362
None => mapper.map_to_supertype(),
363
Some(dtype) => mapper.with_dtype(dtype.clone()),
364
},
365
#[cfg(feature = "dtype-struct")]
366
CumReduceHorizontal { return_dtype, .. } => match return_dtype {
367
None => mapper.with_dtype(DataType::Struct(fields.to_vec())),
368
Some(dtype) => mapper.with_dtype(DataType::Struct(
369
fields
370
.iter()
371
.map(|f| Field::new(f.name().clone(), dtype.clone()))
372
.collect(),
373
)),
374
},
375
#[cfg(feature = "dtype-struct")]
376
CumFoldHorizontal {
377
return_dtype,
378
include_init,
379
..
380
} => match return_dtype {
381
None => mapper.with_dtype(DataType::Struct(
382
fields
383
.iter()
384
.skip(usize::from(!include_init))
385
.map(|f| Field::new(f.name().clone(), fields[0].dtype().clone()))
386
.collect(),
387
)),
388
Some(dtype) => mapper.with_dtype(DataType::Struct(
389
fields
390
.iter()
391
.skip(usize::from(!include_init))
392
.map(|f| Field::new(f.name().clone(), dtype.clone()))
393
.collect(),
394
)),
395
},
396
397
MaxHorizontal => mapper.map_to_supertype(),
398
MinHorizontal => mapper.map_to_supertype(),
399
SumHorizontal { .. } => mapper.map_to_supertype().map(|mut f| {
400
if f.dtype == DataType::Boolean {
401
f.dtype = IDX_DTYPE;
402
}
403
f
404
}),
405
MeanHorizontal { .. } => mapper.map_to_supertype().map(|mut f| {
406
match f.dtype {
407
#[cfg(feature = "dtype-f16")]
408
DataType::Float16 => {},
409
DataType::Float32 => {},
410
_ => {
411
f.dtype = DataType::Float64;
412
},
413
}
414
f
415
}),
416
#[cfg(feature = "ewma")]
417
EwmMean { .. } => mapper.map_numeric_to_float_dtype(true),
418
#[cfg(feature = "ewma_by")]
419
EwmMeanBy { .. } => mapper.map_numeric_to_float_dtype(true),
420
#[cfg(feature = "ewma")]
421
EwmStd { .. } => mapper.map_numeric_to_float_dtype(true),
422
#[cfg(feature = "ewma")]
423
EwmVar { .. } => mapper.var_dtype(),
424
#[cfg(feature = "replace")]
425
Replace => mapper.with_same_dtype(),
426
#[cfg(feature = "replace")]
427
ReplaceStrict { return_dtype } => mapper.replace_dtype(return_dtype.clone()),
428
FillNullWithStrategy(_) => mapper.with_same_dtype(),
429
GatherEvery { .. } => mapper.with_same_dtype(),
430
#[cfg(feature = "reinterpret")]
431
Reinterpret(signed) => {
432
let dt = if *signed {
433
DataType::Int64
434
} else {
435
DataType::UInt64
436
};
437
mapper.with_dtype(dt)
438
},
439
ExtendConstant => mapper.with_same_dtype(),
440
441
RowEncode(..) => mapper.try_map_field(|_| {
442
Ok(Field::new(
443
PlSmallStr::from_static("row_encoded"),
444
DataType::BinaryOffset,
445
))
446
}),
447
#[cfg(feature = "dtype-struct")]
448
RowDecode(fields, _) => mapper.with_dtype(DataType::Struct(fields.to_vec())),
449
}
450
}
451
452
pub(crate) fn output_name(&self) -> Option<OutputName> {
453
match self {
454
#[cfg(feature = "dtype-struct")]
455
IRFunctionExpr::StructExpr(IRStructFunction::FieldByName(name)) => {
456
Some(OutputName::Field(name.clone()))
457
},
458
_ => None,
459
}
460
}
461
}
462
463
pub struct FieldsMapper<'a> {
464
fields: &'a [Field],
465
}
466
467
impl<'a> FieldsMapper<'a> {
468
pub fn new(fields: &'a [Field]) -> Self {
469
Self { fields }
470
}
471
472
pub fn args(&self) -> &[Field] {
473
self.fields
474
}
475
476
/// Field with the same dtype.
477
pub fn with_same_dtype(&self) -> PolarsResult<Field> {
478
self.map_dtype(|dtype| dtype.clone())
479
}
480
481
/// Set a dtype.
482
pub fn with_dtype(&self, dtype: DataType) -> PolarsResult<Field> {
483
Ok(Field::new(self.fields[0].name().clone(), dtype))
484
}
485
486
/// Map a single dtype.
487
pub fn map_dtype(&self, func: impl FnOnce(&DataType) -> DataType) -> PolarsResult<Field> {
488
let dtype = func(self.fields[0].dtype());
489
Ok(Field::new(self.fields[0].name().clone(), dtype))
490
}
491
492
pub fn get_fields_lens(&self) -> usize {
493
self.fields.len()
494
}
495
496
/// Map a single field with a potentially failing mapper function.
497
pub fn try_map_field(
498
&self,
499
func: impl FnOnce(&Field) -> PolarsResult<Field>,
500
) -> PolarsResult<Field> {
501
func(&self.fields[0])
502
}
503
504
pub fn var_dtype(&self) -> PolarsResult<Field> {
505
if self.fields[0].dtype().leaf_dtype().is_duration() {
506
let map_inner = |dt: &DataType| match dt {
507
dt if dt.is_temporal() => {
508
polars_bail!(InvalidOperation: "operation `var` is not supported for `{dt}`")
509
},
510
dt => Ok(dt.clone()),
511
};
512
513
self.try_map_dtype(|dt| match dt {
514
#[cfg(feature = "dtype-array")]
515
DataType::Array(inner, _) => map_inner(inner),
516
DataType::List(inner) => map_inner(inner),
517
_ => map_inner(dt),
518
})
519
} else {
520
self.moment_dtype()
521
}
522
}
523
524
pub fn moment_dtype(&self) -> PolarsResult<Field> {
525
let map_inner = |dt: &DataType| match dt {
526
DataType::Boolean => DataType::Float64,
527
#[cfg(feature = "dtype-f16")]
528
DataType::Float16 => DataType::Float16,
529
DataType::Float32 => DataType::Float32,
530
DataType::Float64 => DataType::Float64,
531
dt if dt.is_primitive_numeric() => DataType::Float64,
532
#[cfg(feature = "dtype-date")]
533
DataType::Date => DataType::Datetime(TimeUnit::Microseconds, None),
534
#[cfg(feature = "dtype-datetime")]
535
dt @ DataType::Datetime(_, _) => dt.clone(),
536
#[cfg(feature = "dtype-duration")]
537
dt @ DataType::Duration(_) => dt.clone(),
538
#[cfg(feature = "dtype-time")]
539
dt @ DataType::Time => dt.clone(),
540
#[cfg(feature = "dtype-decimal")]
541
DataType::Decimal(..) => DataType::Float64,
542
543
// All other types get mapped to a single `null` of the same type.
544
dt => dt.clone(),
545
};
546
547
self.map_dtype(|dt| match dt {
548
#[cfg(feature = "dtype-array")]
549
DataType::Array(inner, _) => map_inner(inner),
550
DataType::List(inner) => map_inner(inner),
551
_ => map_inner(dt),
552
})
553
}
554
555
/// Map to a float supertype.
556
pub fn map_to_float_dtype(&self) -> PolarsResult<Field> {
557
self.map_dtype(|dtype| match dtype {
558
#[cfg(feature = "dtype-f16")]
559
DataType::Float16 => DataType::Float16,
560
DataType::Float32 => DataType::Float32,
561
_ => DataType::Float64,
562
})
563
}
564
565
/// Map to a float supertype if numeric, else preserve
566
pub fn map_numeric_to_float_dtype(&self, coerce_decimal: bool) -> PolarsResult<Field> {
567
self.map_dtype(|dt| {
568
let should_coerce = match dt {
569
#[cfg(feature = "dtype-f16")]
570
DataType::Float16 => false,
571
DataType::Float32 => false,
572
#[cfg(feature = "dtype-decimal")]
573
DataType::Decimal(..) => coerce_decimal,
574
DataType::Boolean => true,
575
dt => dt.is_primitive_numeric(),
576
};
577
578
if should_coerce {
579
DataType::Float64
580
} else {
581
dt.clone()
582
}
583
})
584
}
585
586
/// Map to a physical type.
587
pub fn to_physical_type(&self) -> PolarsResult<Field> {
588
self.map_dtype(|dtype| dtype.to_physical())
589
}
590
591
/// Map a single dtype with a potentially failing mapper function.
592
pub fn try_map_dtype(
593
&self,
594
func: impl FnOnce(&DataType) -> PolarsResult<DataType>,
595
) -> PolarsResult<Field> {
596
let dtype = func(self.fields[0].dtype())?;
597
Ok(Field::new(self.fields[0].name().clone(), dtype))
598
}
599
600
/// Map all dtypes with a potentially failing mapper function.
601
pub fn try_map_dtypes(
602
&self,
603
func: impl FnOnce(&[&DataType]) -> PolarsResult<DataType>,
604
) -> PolarsResult<Field> {
605
let mut fld = self.fields[0].clone();
606
let dtypes = self
607
.fields
608
.iter()
609
.map(|fld| fld.dtype())
610
.collect::<Vec<_>>();
611
let new_type = func(&dtypes)?;
612
fld.coerce(new_type);
613
Ok(fld)
614
}
615
616
/// Map the dtype to the "supertype" of all fields.
617
pub fn map_to_supertype(&self) -> PolarsResult<Field> {
618
let st = args_to_supertype(self.fields)?;
619
let mut first = self.fields[0].clone();
620
first.coerce(st);
621
Ok(first)
622
}
623
624
/// Map the dtype to the dtype of the list/array elements.
625
pub fn map_to_list_and_array_inner_dtype(&self) -> PolarsResult<Field> {
626
let mut first = self.fields[0].clone();
627
let dt = first
628
.dtype()
629
.inner_dtype()
630
.cloned()
631
.unwrap_or_else(|| DataType::Unknown(Default::default()));
632
first.coerce(dt);
633
Ok(first)
634
}
635
636
#[cfg(feature = "dtype-array")]
637
/// Map the dtype to the dtype of the array elements, with typo validation.
638
pub fn try_map_to_array_inner_dtype(&self) -> PolarsResult<Field> {
639
let dt = self.fields[0].dtype();
640
match dt {
641
DataType::Array(_, _) => self.map_to_list_and_array_inner_dtype(),
642
_ => polars_bail!(InvalidOperation: "expected Array type, got: {}", dt),
643
}
644
}
645
646
/// Map the dtypes to the "supertype" of a list of lists.
647
pub fn map_to_list_supertype(&self) -> PolarsResult<Field> {
648
self.try_map_dtypes(|dts| {
649
let mut super_type_inner = None;
650
651
for dt in dts {
652
match dt {
653
DataType::List(inner) => match super_type_inner {
654
None => super_type_inner = Some(*inner.clone()),
655
Some(st_inner) => {
656
super_type_inner = Some(try_get_supertype(&st_inner, inner)?)
657
},
658
},
659
dt => match super_type_inner {
660
None => super_type_inner = Some((*dt).clone()),
661
Some(st_inner) => {
662
super_type_inner = Some(try_get_supertype(&st_inner, dt)?)
663
},
664
},
665
}
666
}
667
Ok(DataType::List(Box::new(super_type_inner.unwrap())))
668
})
669
}
670
671
/// Set the timezone of a datetime dtype.
672
#[cfg(feature = "timezones")]
673
pub fn map_datetime_dtype_timezone(&self, tz: Option<&TimeZone>) -> PolarsResult<Field> {
674
self.try_map_dtype(|dt| {
675
if let DataType::Datetime(tu, _) = dt {
676
Ok(DataType::Datetime(*tu, tz.cloned()))
677
} else {
678
polars_bail!(op = "replace-time-zone", got = dt, expected = "Datetime");
679
}
680
})
681
}
682
683
pub fn sum_dtype(&self) -> PolarsResult<Field> {
684
use DataType::*;
685
self.map_dtype(|dtype| match dtype {
686
Int8 | UInt8 | Int16 | UInt16 => Int64,
687
Boolean => IDX_DTYPE,
688
dt => dt.clone(),
689
})
690
}
691
692
pub fn nested_sum_type(&self) -> PolarsResult<Field> {
693
let mut first = self.fields[0].clone();
694
use DataType::*;
695
let dt = first.dtype().inner_dtype().cloned().ok_or_else(|| {
696
polars_err!(
697
InvalidOperation:"expected List or Array type, got dtype: {}",
698
first.dtype()
699
)
700
})?;
701
702
match dt {
703
Boolean => first.coerce(IDX_DTYPE),
704
UInt8 | Int8 | Int16 | UInt16 => first.coerce(Int64),
705
_ => first.coerce(dt),
706
}
707
Ok(first)
708
}
709
710
pub fn nested_mean_median_type(&self) -> PolarsResult<Field> {
711
let mut first = self.fields[0].clone();
712
use DataType::*;
713
let dt = first.dtype().inner_dtype().cloned().ok_or_else(|| {
714
polars_err!(
715
InvalidOperation:"expected List or Array type, got dtype: {}",
716
first.dtype()
717
)
718
})?;
719
720
let new_dt = match dt {
721
#[cfg(feature = "dtype-datetime")]
722
Date => Datetime(TimeUnit::Microseconds, None),
723
dt if dt.is_temporal() => dt,
724
#[cfg(feature = "dtype-f16")]
725
Float16 => Float16,
726
Float32 => Float32,
727
_ => Float64,
728
};
729
first.coerce(new_dt);
730
Ok(first)
731
}
732
733
pub(super) fn pow_dtype(&self) -> PolarsResult<Field> {
734
let dtype1 = self.fields[0].dtype();
735
let dtype2 = self.fields[1].dtype();
736
let out_dtype = if dtype1.is_integer() {
737
if dtype2.is_float() { dtype2 } else { dtype1 }
738
} else {
739
dtype1
740
};
741
Ok(Field::new(self.fields[0].name().clone(), out_dtype.clone()))
742
}
743
744
pub(super) fn log_dtype(&self) -> PolarsResult<Field> {
745
let dtype1 = self.fields[0].dtype();
746
let dtype2 = self.fields[1].dtype();
747
let out_dtype = if dtype1.is_float() {
748
dtype1
749
} else if dtype2.is_float() {
750
dtype2
751
} else {
752
&DataType::Float64
753
};
754
Ok(Field::new(self.fields[0].name().clone(), out_dtype.clone()))
755
}
756
757
#[cfg(feature = "extract_jsonpath")]
758
pub fn with_opt_dtype(&self, dtype: Option<DataType>) -> PolarsResult<Field> {
759
let dtype = dtype.unwrap_or_else(|| DataType::Unknown(Default::default()));
760
self.with_dtype(dtype)
761
}
762
763
#[cfg(feature = "replace")]
764
pub fn replace_dtype(&self, return_dtype: Option<DataType>) -> PolarsResult<Field> {
765
let dtype = match return_dtype {
766
Some(dtype) => dtype,
767
None => {
768
let new = &self.fields[2];
769
let default = self.fields.get(3);
770
771
// @HACK: Related to implicit implode see #22149.
772
let inner_dtype = new.dtype().inner_dtype().unwrap_or(new.dtype());
773
774
match default {
775
Some(default) => try_get_supertype(default.dtype(), inner_dtype)?,
776
None => inner_dtype.clone(),
777
}
778
},
779
};
780
self.with_dtype(dtype)
781
}
782
783
fn ensure_satisfies(
784
self,
785
mut f: impl FnMut(usize, &DataType) -> bool,
786
op: &'static str,
787
) -> PolarsResult<Self> {
788
for (i, field) in self.fields.iter().enumerate() {
789
polars_ensure!(
790
f(i, field.dtype()),
791
opidx = op,
792
idx = i,
793
self.fields[i].dtype()
794
);
795
}
796
797
Ok(self)
798
}
799
}
800
801
pub(crate) fn args_to_supertype<D: AsRef<DataType>>(dtypes: &[D]) -> PolarsResult<DataType> {
802
let mut st = dtypes[0].as_ref().clone();
803
for dt in &dtypes[1..] {
804
st = try_get_supertype(&st, dt.as_ref())?
805
}
806
807
match (dtypes[0].as_ref(), &st) {
808
#[cfg(feature = "dtype-categorical")]
809
(cat @ DataType::Categorical(_, _), DataType::String) => st = cat.clone(),
810
_ => {
811
if let DataType::Unknown(kind) = st {
812
match kind {
813
UnknownKind::Float => st = DataType::Float64,
814
UnknownKind::Int(v) => {
815
st = materialize_dyn_int(v).dtype();
816
},
817
UnknownKind::Str => st = DataType::String,
818
_ => {},
819
}
820
}
821
},
822
}
823
824
Ok(st)
825
}
826
827