// Path: blob/main/crates/polars-plan/src/plans/aexpr/schema.rs
#[cfg(feature = "dtype-decimal")]
use polars_compute::decimal::DEC128_MAX_PREC;
use polars_core::series::arithmetic::NumericListOp;
use polars_utils::format_pl_smallstr;
use recursive::recursive;

use super::*;
use crate::constants::{
    POLARS_ELEMENT, POLARS_STRUCTFIELDS, get_literal_name, get_pl_element_name,
    get_pl_structfields_name,
};

/// Resolve the field of `node` purely for validation.
///
/// Any error raised while resolving the node is propagated; the resolved field itself is
/// discarded.
fn validate_expr(node: Node, ctx: &ToFieldContext) -> PolarsResult<()> {
    ctx.arena.get(node).to_field_impl(ctx).map(|_| ())
}

/// Borrowed inputs needed to resolve the output field of an expression: the arena holding
/// the expression nodes and the schema the expression is resolved against.
#[derive(Debug)]
pub struct ToFieldContext<'a> {
    arena: &'a Arena<AExpr>,
    schema: &'a Schema,
}

impl<'a> ToFieldContext<'a> {
    /// Bundle an expression arena and a schema into a resolution context.
    pub fn new(arena: &'a Arena<AExpr>, schema: &'a Schema) -> Self {
        Self { arena, schema }
    }
}

impl AExpr {
    /// Get the output dtype of the expression: the dtype of [`AExpr::to_field`].
    pub fn to_dtype(&self, ctx: &ToFieldContext<'_>) -> PolarsResult<DataType> {
        self.to_field(ctx).map(|f| f.dtype)
    }

    /// Get Field result of the expression. The schema is the input data. The result will
    /// not be coerced (also known as auto-implode): this is the responsibility of the caller.
    pub fn to_field(&self, ctx: &ToFieldContext<'_>) -> PolarsResult<Field> {
        self.to_field_impl(ctx)
    }

    /// Get Field result of the expression. The schema is the input data.
    ///
    /// Child nodes are looked up in `ctx.arena` and resolved recursively. Nodes that only
    /// parameterize the operation (partition/order keys of `Over`, gather indices, filter
    /// predicates, slice offset/length) are resolved for validation only: their errors are
    /// propagated but their fields do not influence the result.
    ///
    /// # Errors
    ///
    /// Returns an error when a referenced column / struct field / element is not present in
    /// the schema, when an operation is not supported for the resolved dtypes, or when a
    /// function expression received no inputs.
    #[recursive]
    pub fn to_field_impl(&self, ctx: &ToFieldContext) -> PolarsResult<Field> {
        use AExpr::*;
        use DataType::*;
        match self {
            // The element field is looked up under the `POLARS_ELEMENT` schema key.
            Element => ctx
                .schema
                .get_field(POLARS_ELEMENT)
                .ok_or_else(|| polars_err!(invalid_element_use)),

            Len => Ok(Field::new(PlSmallStr::from_static(LEN), IDX_DTYPE)),
            #[cfg(feature = "dynamic_group_by")]
            Rolling { function, .. } => {
                let e = ctx.arena.get(*function);
                let mut field = e.to_field_impl(ctx)?;
                // Implicit implode
                if !is_scalar_ae(*function, ctx.arena) {
                    field.dtype = field.dtype.implode();
                }
                Ok(field)
            },
            Over {
                function,
                partition_by,
                order_by,
                mapping,
            } => {
                // Partition and order expressions are only checked for errors.
                for node in partition_by {
                    validate_expr(*node, ctx)?;
                }
                if let Some((node, _)) = order_by {
                    validate_expr(*node, ctx)?;
                }

                let e = ctx.arena.get(*function);
                let mut field = e.to_field_impl(ctx)?;

                // Only the `Join` mapping of a non-scalar function implodes the dtype.
                if matches!(mapping, WindowMapping::Join) && !is_scalar_ae(*function, ctx.arena) {
                    field.dtype = field.dtype.implode();
                }

                Ok(field)
            },
            Explode { expr, .. } => {
                let field = ctx.arena.get(*expr).to_field_impl(ctx)?;
                // Unwrap one nesting level of List / Array; any other dtype passes through.
                let field = match field.dtype() {
                    List(inner) => Field::new(field.name().clone(), *inner.clone()),
                    #[cfg(feature = "dtype-array")]
                    Array(inner, ..) => Field::new(field.name().clone(), *inner.clone()),
                    _ => field,
                };

                Ok(field)
            },
            Column(name) => ctx
                .schema
                .get_field(name)
                .ok_or_else(|| PolarsError::ColumnNotFound(name.to_string().into())),
            #[cfg(feature = "dtype-struct")]
            StructField(name) => {
                // The struct being evaluated is looked up under the `POLARS_STRUCTFIELDS`
                // schema key.
                let struct_field = ctx
                    .schema
                    .get_field(POLARS_STRUCTFIELDS)
                    .ok_or_else(|| polars_err!(invalid_field_use))?;
                let DataType::Struct(fields) = struct_field.dtype() else {
                    return Err(polars_err!(
                        InvalidOperation: "expected `Struct` dtype for `with_fields` Expr, got `{}`",
                        struct_field.dtype()));
                };
                // @NOTE. Linear search performance is not ideal. An alternative approach
                // would be to map each field to a new column with a temporary name (see streaming engine),
                // and extend the schema accordingly.
                for f in fields {
                    if f.name() == name {
                        return Ok(f.clone());
                    }
                }
                Err(PolarsError::StructFieldNotFound(name.to_string().into()))
            },
            Literal(sv) => Ok(match sv {
                LiteralValue::Series(s) => s.field().into_owned(),
                _ => Field::new(sv.output_column_name(), sv.get_datatype()),
            }),
            BinaryExpr { left, right, op } => {
                use DataType::*;

                let field = match op {
                    // Comparisons and logical operators: Boolean, named after the left
                    // operand. The right operand is not resolved here.
                    Operator::Lt
                    | Operator::Gt
                    | Operator::Eq
                    | Operator::NotEq
                    | Operator::LogicalAnd
                    | Operator::LtEq
                    | Operator::GtEq
                    | Operator::NotEqValidity
                    | Operator::EqValidity
                    | Operator::LogicalOr => {
                        let out_field;
                        let out_name = {
                            out_field = ctx.arena.get(*left).to_field_impl(ctx)?;
                            out_field.name()
                        };
                        Field::new(out_name.clone(), Boolean)
                    },
                    Operator::TrueDivide => get_truediv_field(*left, *right, ctx)?,
                    _ => get_arithmetic_field(*left, *right, *op, ctx)?,
                };

                Ok(field)
            },
            Sort { expr, .. } => ctx.arena.get(*expr).to_field_impl(ctx),
            Gather { expr, idx, .. } => {
                validate_expr(*idx, ctx)?;
                ctx.arena.get(*expr).to_field_impl(ctx)
            },
            SortBy { expr, .. } => ctx.arena.get(*expr).to_field_impl(ctx),
            Filter { input, by } => {
                validate_expr(*by, ctx)?;
                ctx.arena.get(*input).to_field_impl(ctx)
            },
            Agg(agg) => {
                use IRAggExpr::*;
                match agg {
                    // These aggregations keep the field of their input.
                    Max { input: expr, .. }
                    | Min { input: expr, .. }
                    | First(expr)
                    | FirstNonNull(expr)
                    | Last(expr)
                    | LastNonNull(expr) => ctx.arena.get(*expr).to_field_impl(ctx),
                    Item { input: expr, .. } => ctx.arena.get(*expr).to_field_impl(ctx),
                    Sum(expr) => {
                        let mut field = ctx.arena.get(*expr).to_field_impl(ctx)?;
                        // `Some(dt)` means the sum is produced in a different dtype than the
                        // input; `None` keeps the input dtype.
                        let dt = match field.dtype() {
                            String | Binary | BinaryOffset | List(_) => {
                                polars_bail!(
                                    InvalidOperation: "`sum` operation not supported for dtype `{}`",
                                    field.dtype()
                                )
                            },
                            #[cfg(feature = "dtype-array")]
                            Array(_, _) => {
                                polars_bail!(
                                    InvalidOperation: "`sum` operation not supported for dtype `{}`",
                                    field.dtype()
                                )
                            },
                            #[cfg(feature = "dtype-struct")]
                            Struct(_) => {
                                polars_bail!(
                                    InvalidOperation: "`sum` operation not supported for dtype `{}`",
                                    field.dtype()
                                )
                            },
                            Boolean => Some(IDX_DTYPE),
                            UInt8 | Int8 | Int16 | UInt16 => Some(Int64),
                            _ => None,
                        };
                        if let Some(dt) = dt {
                            field.coerce(dt);
                        }
                        Ok(field)
                    },
                    Median(expr) => {
                        let field = [ctx.arena.get(*expr).to_field_impl(ctx)?];
                        let mapper = FieldsMapper::new(&field);
                        mapper.moment_dtype()
                    },
                    Mean(expr) => {
                        let field = [ctx.arena.get(*expr).to_field_impl(ctx)?];
                        let mapper = FieldsMapper::new(&field);
                        mapper.moment_dtype()
                    },
                    Implode(expr) => {
                        let mut field = ctx.arena.get(*expr).to_field_impl(ctx)?;
                        field.coerce(DataType::List(field.dtype().clone().into()));
                        Ok(field)
                    },
                    Std(expr, _) => {
                        let field = [ctx.arena.get(*expr).to_field_impl(ctx)?];
                        let mapper = FieldsMapper::new(&field);
                        mapper.moment_dtype()
                    },
                    Var(expr, _) => {
                        let field = [ctx.arena.get(*expr).to_field_impl(ctx)?];
                        let mapper = FieldsMapper::new(&field);
                        mapper.var_dtype()
                    },
                    NUnique(expr) => {
                        let mut field = ctx.arena.get(*expr).to_field_impl(ctx)?;
                        field.coerce(IDX_DTYPE);
                        Ok(field)
                    },
                    Count { input, .. } => {
                        let mut field = ctx.arena.get(*input).to_field_impl(ctx)?;
                        field.coerce(IDX_DTYPE);
                        Ok(field)
                    },
                    AggGroups(expr) => {
                        let mut field = ctx.arena.get(*expr).to_field_impl(ctx)?;
                        field.coerce(IDX_DTYPE.implode());
                        Ok(field)
                    },
                    Quantile { expr, .. } => {
                        let field = [ctx.arena.get(*expr).to_field_impl(ctx)?];
                        let mapper = FieldsMapper::new(&field);
                        mapper.moment_dtype()
                    },
                }
            },
            // The cast target dtype is taken as-is; only the name comes from the input.
            Cast { expr, dtype, .. } => {
                let field = ctx.arena.get(*expr).to_field_impl(ctx)?;
                Ok(Field::new(field.name().clone(), dtype.clone()))
            },
            Ternary { truthy, falsy, .. } => {
                // During aggregation:
                // left: col(foo): list<T> nesting: 1
                // right; col(foo).first(): T nesting: 0
                // col(foo) + col(foo).first() will have nesting 1 as we still maintain the groups list.
                let mut truthy = ctx.arena.get(*truthy).to_field_impl(ctx)?;
                let falsy = ctx.arena.get(*falsy).to_field_impl(ctx)?;

                // A Null `truthy` branch adopts the `falsy` dtype; otherwise both branches
                // are unified to their supertype. The name comes from the `truthy` branch.
                let st = if let DataType::Null = *truthy.dtype() {
                    falsy.dtype().clone()
                } else {
                    try_get_supertype(truthy.dtype(), falsy.dtype())?
                };

                truthy.coerce(st);
                Ok(truthy)
            },
            AnonymousFunction {
                input,
                function,
                fmt_str,
                ..
            } => {
                let fields = func_args_to_fields(input, ctx)?;
                polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", fmt_str);
                let function = function.clone().materialize()?;
                let out = function.get_field(ctx.schema, &fields)?;
                Ok(out)
            },
            AnonymousAgg {
                input,
                function,
                fmt_str,
                ..
            } => {
                let fields = func_args_to_fields(input, ctx)?;
                polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", fmt_str);
                let function = function.clone().materialize()?;
                let out = function.get_field(ctx.schema, &fields)?;
                Ok(out)
            },
            Eval {
                expr,
                evaluation,
                variant,
            } => {
                let field = ctx.arena.get(*expr).to_field_impl(ctx)?;

                // Resolve the evaluation expression against a copy of the schema that is
                // extended with the element dtype of the input.
                let element_dtype = variant.element_dtype(field.dtype())?;
                let mut evaluation_schema = ctx.schema.clone();
                evaluation_schema.insert(get_pl_element_name(), element_dtype.clone());
                let mut output_field = ctx
                    .arena
                    .get(*evaluation)
                    .to_field_impl(&ToFieldContext::new(ctx.arena, &evaluation_schema))?;
                output_field.dtype = output_field.dtype.materialize_unknown(false)?;
                let eval_is_scalar = is_scalar_ae(*evaluation, ctx.arena);

                output_field.dtype =
                    variant.output_dtype(field.dtype(), output_field.dtype, eval_is_scalar)?;
                // The output keeps the name of the input expression.
                output_field.name = field.name;

                Ok(output_field)
            },
            #[cfg(feature = "dtype-struct")]
            StructEval { expr, evaluation } => {
                let struct_field = ctx.arena.get(*expr).to_field_impl(ctx)?;
                // Resolve the evaluation expressions against a copy of the schema that is
                // extended with the dtype of the input struct.
                let mut evaluation_schema = ctx.schema.clone();
                evaluation_schema.insert(get_pl_structfields_name(), struct_field.dtype().clone());

                let eval_fields = func_args_to_fields(
                    evaluation,
                    &ToFieldContext::new(ctx.arena, &evaluation_schema),
                )?;

                // Merge evaluation fields into the expr Struct
                // NOTE(review): an evaluation field whose name already exists replaces the
                // dtype of that struct field; presumably it keeps the original position
                // (index-map semantics) - confirm against `PlIndexMap`.
                if let DataType::Struct(expr_fields) = struct_field.dtype() {
                    let mut fields_map =
                        PlIndexMap::with_capacity(expr_fields.len() + eval_fields.len());
                    for field in expr_fields {
                        fields_map.insert(field.name(), field.dtype());
                    }
                    for field in &eval_fields {
                        fields_map.insert(field.name(), field.dtype());
                    }
                    let dtype = DataType::Struct(
                        fields_map
                            .iter()
                            .map(|(&name, &dtype)| Field::new(name.clone(), dtype.clone()))
                            .collect(),
                    );
                    let mut out = struct_field.clone();
                    out.coerce(dtype);
                    Ok(out)
                } else {
                    let dt = struct_field.dtype();
                    polars_bail!(op = "with_fields", got = dt, expected = "Struct")
                }
            },
            Function {
                function,
                input,
                options: _,
            } => {
                #[cfg(feature = "strings")]
                {
                    // A string `Format` without inputs is exempt from the "no inputs" check
                    // below and resolves to a String field with the literal name.
                    if input.is_empty()
                        && matches!(
                            &function,
                            IRFunctionExpr::StringExpr(IRStringFunction::Format { .. })
                        )
                    {
                        return Ok(Field::new(get_literal_name(), DataType::String));
                    }
                }

                let fields = func_args_to_fields(input, ctx)?;
                polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", function);
                let out = function.get_field(ctx.schema, &fields)?;

                Ok(out)
            },
            Slice {
                input,
                offset,
                length,
            } => {
                validate_expr(*offset, ctx)?;
                validate_expr(*length, ctx)?;

                ctx.arena.get(*input).to_field_impl(ctx)
            },
        }
    }

    /// Get the output name of the expression.
    ///
    /// Only the expression arena is consulted; no schema is needed and no dtypes are
    /// resolved. Most expressions take the name of their first / left-most input.
    pub fn to_name(&self, expr_arena: &Arena<AExpr>) -> PlSmallStr {
        use AExpr::*;
        use IRAggExpr::*;
        match self {
            Element => PlSmallStr::EMPTY,
            Len => crate::constants::get_len_name(),
            #[cfg(feature = "dynamic_group_by")]
            Rolling {
                function,
                index_column: _,
                period: _,
                offset: _,
                closed_window: _,
            } => expr_arena.get(*function).to_name(expr_arena),
            // All of these inherit the name of the bound `expr` node.
            Over {
                function: expr,
                partition_by: _,
                order_by: _,
                mapping: _,
            }
            | BinaryExpr { left: expr, .. }
            | Explode { expr, .. }
            | Sort { expr, .. }
            | Gather { expr, .. }
            | SortBy { expr, .. }
            | Filter { input: expr, .. }
            | Cast { expr, .. }
            | Ternary { truthy: expr, .. }
            | Eval { expr, .. }
            | Slice { input: expr, .. }
            | Agg(Min { input: expr, .. })
            | Agg(Max { input: expr, .. })
            | Agg(First(expr))
            | Agg(FirstNonNull(expr))
            | Agg(Last(expr))
            | Agg(LastNonNull(expr))
            | Agg(Item { input: expr, .. })
            | Agg(Sum(expr))
            | Agg(Median(expr))
            | Agg(Mean(expr))
            | Agg(Implode(expr))
            | Agg(Std(expr, _))
            | Agg(Var(expr, _))
            | Agg(NUnique(expr))
            | Agg(Count { input: expr, .. })
            | Agg(AggGroups(expr))
            | Agg(Quantile { expr, .. }) => expr_arena.get(*expr).to_name(expr_arena),
            // Without inputs the format string is used as the name.
            AnonymousFunction { input, fmt_str, .. } | AnonymousAgg { input, fmt_str, .. } => {
                if input.is_empty() {
                    fmt_str.as_ref().clone()
                } else {
                    input[0].output_name().clone()
                }
            },
            #[cfg(feature = "dtype-struct")]
            StructEval { expr, .. } => expr_arena.get(*expr).to_name(expr_arena),
            // A name supplied by the function wins; otherwise the first input's output name,
            // or the formatted function when there are no inputs.
            Function {
                input, function, ..
            } => match function.output_name().and_then(|v| v.into_inner()) {
                Some(name) => name,
                None if input.is_empty() => format_pl_smallstr!("{}", &function),
                None => input[0].output_name().clone(),
            },
            Column(name) => name.clone(),
            #[cfg(feature = "dtype-struct")]
            StructField(name) => name.clone(),
            Literal(lv) => lv.output_column_name().clone(),
        }
    }
}

/// Resolve every input expression to a field, renamed to the expression's output name.
///
/// Stops at, and returns, the first error encountered.
fn func_args_to_fields(input: &[ExprIR], ctx: &ToFieldContext) -> PolarsResult<Vec<Field>> {
    input
        .iter()
        .map(|e| {
            ctx.arena.get(e.node()).to_field_impl(ctx).map(|mut field| {
                field.name = e.output_name().clone();
                field
            })
        })
        .collect()
}

/// Resolve the output field of an arithmetic binary expression `left <op> right`.
///
/// Called from `to_field_impl` for every binary operator that is neither a
/// comparison / logical operator nor `TrueDivide`. `Minus` and `Plus` have dedicated
/// temporal rules; every other operator goes through the final arm. The returned field
/// always carries the name of the left operand.
///
/// # Errors
///
/// Returns `InvalidOperation` for dtype combinations the operator does not allow, and
/// propagates supertype-resolution errors.
// NOTE(review): this function takes four arguments, so the `too_many_arguments` allow below
// looks stale - confirm before removing.
#[allow(clippy::too_many_arguments)]
fn get_arithmetic_field(
    left: Node,
    right: Node,
    op: Operator,
    ctx: &ToFieldContext,
) -> PolarsResult<Field> {
    use DataType::*;
    let left_ae = ctx.arena.get(left);
    let right_ae = ctx.arena.get(right);

    // don't traverse tree until strictly needed. Can have terrible performance.
    // # 3210

    // take the left field as a whole.
    // don't take dtype and name separate as that splits the tree every node
    // leading to quadratic behavior. # 4736
    //
    // further right_type is only determined when needed.
    let mut left_field = left_ae.to_field_impl(ctx)?;

    let super_type = match op {
        Operator::Minus => {
            let right_type = right_ae.to_field_impl(ctx)?.dtype;
            match (&left_field.dtype, &right_type) {
                #[cfg(feature = "dtype-struct")]
                (Struct(_), Struct(_)) => {
                    return Ok(left_field);
                },
                // This matches the engine output. TODO: revisit pending resolution of GH issue #23797
                #[cfg(feature = "dtype-struct")]
                (Struct(_), r) if r.is_numeric() => {
                    return Ok(left_field);
                },
                (Duration(_), Datetime(_, _))
                | (Datetime(_, _), Duration(_))
                | (Duration(_), Date)
                | (Date, Duration(_))
                | (Duration(_), Time)
                | (Time, Duration(_)) => try_get_supertype(left_field.dtype(), &right_type)?,
                (Datetime(tu, _), Date) | (Date, Datetime(tu, _)) => Duration(*tu),
                // T - T != T if T is a datetime / date
                (Datetime(tul, _), Datetime(tur, _)) => Duration(get_time_units(tul, tur)),
                (_, Datetime(_, _)) | (Datetime(_, _), _) => {
                    polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
                },
                (Date, Date) => Duration(TimeUnit::Microseconds),
                (_, Date) | (Date, _) => {
                    polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
                },
                (Duration(tul), Duration(tur)) => Duration(get_time_units(tul, tur)),
                (_, Duration(_)) | (Duration(_), _) => {
                    polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
                },
                (Time, Time) => Duration(TimeUnit::Nanoseconds),
                (_, Time) | (Time, _) => {
                    polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
                },
                (l @ List(a), r @ List(b))
                    if ![a, b]
                        .into_iter()
                        .all(|x| x.is_supported_list_arithmetic_input()) =>
                {
                    polars_bail!(
                        InvalidOperation:
                        "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
                        "sub", l, r,
                    )
                },
                (list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => {
                    // TODO: This should not use `try_get_supertype()`! It should instead recursively use the enclosing match block.
                    // Otherwise we will silently permit addition operations between logical types (see above).
                    // This currently doesn't cause any problems because the list arithmetic implementation checks and raises errors
                    // if the leaf types aren't numeric, but it means we don't raise an error until execution and the DSL schema
                    // may be incorrect.
                    list_dtype.cast_leaf(NumericListOp::sub().try_get_leaf_supertype(
                        list_dtype.leaf_dtype(),
                        other_dtype.leaf_dtype(),
                    )?)
                },
                #[cfg(feature = "dtype-array")]
                (list_dtype @ Array(..), other_dtype) | (other_dtype, list_dtype @ Array(..)) => {
                    list_dtype.cast_leaf(try_get_supertype(
                        list_dtype.leaf_dtype(),
                        other_dtype.leaf_dtype(),
                    )?)
                },
                // Decimal results use the maximum precision and the larger of both scales.
                #[cfg(feature = "dtype-decimal")]
                (Decimal(_, scale_left), Decimal(_, scale_right)) => {
                    Decimal(DEC128_MAX_PREC, *scale_left.max(scale_right))
                },
                (left, right) => try_get_supertype(left, right)?,
            }
        },
        Operator::Plus => {
            let right_type = right_ae.to_field_impl(ctx)?.dtype;
            match (&left_field.dtype, &right_type) {
                #[cfg(feature = "dtype-struct")]
                (Struct(_), Struct(_)) => {
                    return Ok(left_field);
                },
                // This matches the engine output. TODO: revisit pending resolution of GH issue #23797
                #[cfg(feature = "dtype-struct")]
                (Struct(_), r) if r.is_numeric() => {
                    return Ok(left_field);
                },
                (Duration(_), Datetime(_, _))
                | (Datetime(_, _), Duration(_))
                | (Duration(_), Date)
                | (Date, Duration(_))
                | (Duration(_), Time)
                | (Time, Duration(_)) => try_get_supertype(left_field.dtype(), &right_type)?,
                (_, Datetime(_, _))
                | (Datetime(_, _), _)
                | (_, Date)
                | (Date, _)
                | (Time, _)
                | (_, Time) => {
                    polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
                },
                (Duration(tul), Duration(tur)) => Duration(get_time_units(tul, tur)),
                (_, Duration(_)) | (Duration(_), _) => {
                    polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
                },
                (Boolean, Boolean) => IDX_DTYPE,
                (l @ List(a), r @ List(b))
                    if ![a, b]
                        .into_iter()
                        .all(|x| x.is_supported_list_arithmetic_input()) =>
                {
                    polars_bail!(
                        InvalidOperation:
                        "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
                        "add", l, r,
                    )
                },
                (list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => {
                    list_dtype.cast_leaf(NumericListOp::add().try_get_leaf_supertype(
                        list_dtype.leaf_dtype(),
                        other_dtype.leaf_dtype(),
                    )?)
                },
                #[cfg(feature = "dtype-array")]
                (list_dtype @ Array(..), other_dtype) | (other_dtype, list_dtype @ Array(..)) => {
                    list_dtype.cast_leaf(try_get_supertype(
                        list_dtype.leaf_dtype(),
                        other_dtype.leaf_dtype(),
                    )?)
                },
                // Decimal results use the maximum precision and the larger of both scales.
                #[cfg(feature = "dtype-decimal")]
                (Decimal(_, scale_left), Decimal(_, scale_right)) => {
                    Decimal(DEC128_MAX_PREC, *scale_left.max(scale_right))
                },
                (left, right) => try_get_supertype(left, right)?,
            }
        },
        _ => {
            let right_type = right_ae.to_field_impl(ctx)?.dtype;

            match (&left_field.dtype, &right_type) {
                #[cfg(feature = "dtype-struct")]
                (Struct(_), Struct(_)) => {
                    return Ok(left_field);
                },
                // This matches the engine output. TODO: revisit pending resolution of GH issue #23797
                #[cfg(feature = "dtype-struct")]
                (Struct(_), r) if r.is_numeric() => {
                    return Ok(left_field);
                },
                (Datetime(_, _), _)
                | (_, Datetime(_, _))
                | (Time, _)
                | (_, Time)
                | (Date, _)
                | (_, Date) => {
                    polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
                },
                (Duration(_), Duration(_)) => {
                    // True divide handled somewhere else
                    polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
                },
                // numeric * Duration -> Duration (the right dtype); other operators are rejected.
                (l, Duration(_)) if l.is_primitive_numeric() => match op {
                    Operator::Multiply => {
                        left_field.coerce(right_type);
                        return Ok(left_field);
                    },
                    _ => {
                        polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
                    },
                },
                // Duration * numeric -> Duration (the left dtype); other operators are rejected.
                (Duration(_), r) if r.is_primitive_numeric() => match op {
                    Operator::Multiply => {
                        return Ok(left_field);
                    },
                    _ => {
                        polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
                    },
                },
                #[cfg(feature = "dtype-decimal")]
                (Decimal(_, scale_left), Decimal(_, scale_right)) => {
                    let dtype = Decimal(DEC128_MAX_PREC, *scale_left.max(scale_right));
                    left_field.coerce(dtype);
                    return Ok(left_field);
                },

                (l @ List(a), r @ List(b))
                    if ![a, b]
                        .into_iter()
                        .all(|x| x.is_supported_list_arithmetic_input()) =>
                {
                    polars_bail!(
                        InvalidOperation:
                        "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
                        op, l, r,
                    )
                },
                // List<->primitive operations can be done directly after casting to the primitive
                // supertype for the primitive values on both sides.
                (list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => {
                    let dtype = list_dtype.cast_leaf(try_get_supertype(
                        list_dtype.leaf_dtype(),
                        other_dtype.leaf_dtype(),
                    )?);
                    left_field.coerce(dtype);
                    return Ok(left_field);
                },
                #[cfg(feature = "dtype-array")]
                (list_dtype @ Array(..), other_dtype) | (other_dtype, list_dtype @ Array(..)) => {
                    let dtype = list_dtype.cast_leaf(try_get_supertype(
                        list_dtype.leaf_dtype(),
                        other_dtype.leaf_dtype(),
                    )?);
                    left_field.coerce(dtype);
                    return Ok(left_field);
                },
                _ => {
                    // Avoid needlessly type casting numeric columns during arithmetic
                    // with literals.
                    if (left_field.dtype.is_integer() && right_type.is_integer())
                        || (left_field.dtype.is_float() && right_type.is_float())
                    {
                        match (left_ae, right_ae) {
                            (AExpr::Literal(_), AExpr::Literal(_)) => {},
                            (AExpr::Literal(_), _) if left_field.dtype.is_unknown() => {
                                // left literal will be coerced to match right type
                                left_field.coerce(right_type);
                                return Ok(left_field);
                            },
                            (_, AExpr::Literal(_)) if right_type.is_unknown() => {
                                // right literal will be coerced to match left type
                                return Ok(left_field);
                            },
                            _ => {},
                        }
                    }
                },
            }

            // No special case applied: fall back to the plain supertype.
            try_get_supertype(&left_field.dtype, &right_type)?
        },
    };

    left_field.coerce(super_type);
    Ok(left_field)
}

/// Resolve the output field of a true division `left / right`.
///
/// The result keeps the name of the left operand and takes the dtype computed by
/// `get_truediv_dtype`.
fn get_truediv_field(left: Node, right: Node, ctx: &ToFieldContext) -> PolarsResult<Field> {
    let mut left_field = ctx.arena.get(left).to_field_impl(ctx)?;
    let right_field = ctx.arena.get(right).to_field_impl(ctx)?;
    let out_type = get_truediv_dtype(left_field.dtype(), right_field.dtype())?;
    left_field.coerce(out_type);
    Ok(left_field)
}

/// Compute the output dtype of a true division between two dtypes.
///
/// Nested dtypes (struct fields, list / array leaves) are resolved recursively. The match
/// arms are order-sensitive: specific pairs must stay ahead of the broader fallbacks.
///
/// # Errors
///
/// Returns `InvalidOperation` for combinations that cannot be divided, e.g. structs of
/// mismatching length or with non-numeric fields, lists with unsupported inner types,
/// `String` operands, or a temporal left operand (other than `Duration` divided by a
/// `Duration` or a primitive numeric).
fn get_truediv_dtype(left_dtype: &DataType, right_dtype: &DataType) -> PolarsResult<DataType> {
    use DataType::*;

    // TODO: Re-investigate this. A lot of "_" is being used on the RHS match because this code
    // originally (mostly) only looked at the LHS dtype.
    let out_type = match (left_dtype, right_dtype) {
        #[cfg(feature = "dtype-struct")]
        (Struct(a), Struct(b)) => {
            polars_ensure!(a.len() == b.len() || b.len() == 1,
                InvalidOperation: "cannot {} two structs of different length (left: {}, right: {})",
                "div", a.len(), b.len()
            );
            let mut fields = Vec::with_capacity(a.len());
            // In case b.len() == 1, we broadcast the first field (b[0]).
            // Safety is assured by the constraints above.
            let b_iter = (0..a.len()).map(|i| b.get(i.min(b.len() - 1)).unwrap());
            for (left, right) in a.iter().zip(b_iter) {
                let name = left.name.clone();
                let (left, right) = (left.dtype(), right.dtype());
                if !(left.is_numeric() && right.is_numeric()) {
                    polars_bail!(InvalidOperation:
                        "cannot {} two structs with non-numeric fields: (left: {}, right: {})",
                        "div", left, right,)
                };
                let field = Field::new(name, get_truediv_dtype(left, right)?);
                fields.push(field);
            }
            Struct(fields)
        },
        // Struct / numeric: every struct field is divided by the numeric dtype.
        #[cfg(feature = "dtype-struct")]
        (Struct(a), n) if n.is_numeric() => {
            let mut fields = Vec::with_capacity(a.len());
            for left in a.iter() {
                let name = left.name.clone();
                let left = left.dtype();
                if !(left.is_numeric()) {
                    polars_bail!(InvalidOperation:
                        "cannot {} a struct with non-numeric field: (left: {})",
                        "div", left)
                };
                let field = Field::new(name, get_truediv_dtype(left, n)?);
                fields.push(field);
            }
            Struct(fields)
        },
        (l @ List(a), r @ List(b))
            if ![a, b]
                .into_iter()
                .all(|x| x.is_supported_list_arithmetic_input()) =>
        {
            polars_bail!(
                InvalidOperation:
                "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
                "div", l, r,
            )
        },
        // List / Array: divide the leaf dtypes and keep the nesting of the list side.
        (list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => {
            let dtype = get_truediv_dtype(list_dtype.leaf_dtype(), other_dtype.leaf_dtype())?;
            list_dtype.cast_leaf(dtype)
        },
        #[cfg(feature = "dtype-array")]
        (list_dtype @ Array(..), other_dtype) | (other_dtype, list_dtype @ Array(..)) => {
            let dtype = get_truediv_dtype(list_dtype.leaf_dtype(), other_dtype.leaf_dtype())?;
            list_dtype.cast_leaf(dtype)
        },
        #[cfg(feature = "dtype-f16")]
        (Boolean, Float16) => Float16,
        (Boolean, Float32) => Float32,
        (Boolean, b) if b.is_numeric() => Float64,
        (Boolean, Boolean) => Float64,
        #[cfg(all(feature = "dtype-f16", feature = "dtype-u8"))]
        (Float16, UInt8 | Int8) => Float16,
        #[cfg(all(feature = "dtype-f16", feature = "dtype-u16"))]
        (Float16, UInt16 | Int16) => Float32,
        #[cfg(feature = "dtype-f16")]
        (Float16, Unknown(UnknownKind::Int(_))) => Float16,
        #[cfg(feature = "dtype-f16")]
        (Float16, other) if other.is_integer() => Float64,
        #[cfg(feature = "dtype-f16")]
        (Float16, Float32) => Float32,
        #[cfg(feature = "dtype-f16")]
        (Float16, Float64) => Float64,
        #[cfg(feature = "dtype-f16")]
        (Float16, _) => Float16,
        #[cfg(feature = "dtype-u8")]
        (Float32, UInt8 | Int8) => Float32,
        #[cfg(feature = "dtype-u16")]
        (Float32, UInt16 | Int16) => Float32,
        (Float32, Unknown(UnknownKind::Int(_))) => Float32,
        (Float32, other) if other.is_integer() => Float64,
        (Float32, Float64) => Float64,
        (Float32, _) => Float32,
        (String, _) | (_, String) => polars_bail!(
            InvalidOperation: "division with 'String' datatypes is not allowed"
        ),
        #[cfg(feature = "dtype-decimal")]
        (Decimal(_, scale_left), Decimal(_, scale_right)) => {
            Decimal(DEC128_MAX_PREC, *scale_left.max(scale_right))
        },
        #[cfg(all(feature = "dtype-u8", feature = "dtype-f16"))]
        (UInt8 | Int8, Float16) => Float16,
        #[cfg(all(feature = "dtype-u16", feature = "dtype-f16"))]
        (UInt16 | Int16, Float16) => Float32,
        #[cfg(feature = "dtype-u8")]
        (UInt8 | Int8, Float32) => Float32,
        #[cfg(feature = "dtype-u16")]
        (UInt16 | Int16, Float32) => Float32,
        // Any remaining primitive numeric left operand divides to Float64.
        (dt, _) if dt.is_primitive_numeric() => Float64,
        #[cfg(feature = "dtype-duration")]
        (Duration(_), Duration(_)) => Float64,
        #[cfg(feature = "dtype-duration")]
        (Duration(_), dt) if dt.is_primitive_numeric() => left_dtype.clone(),
        #[cfg(feature = "dtype-duration")]
        (Duration(_), dt) => {
            polars_bail!(InvalidOperation: "true division of {} with {} is not allowed", left_dtype, dt)
        },
        #[cfg(feature = "dtype-datetime")]
        (Datetime(_, _), _) => {
            polars_bail!(InvalidOperation: "division of 'Datetime' datatype is not allowed")
        },
        #[cfg(feature = "dtype-time")]
        (Time, _) => polars_bail!(InvalidOperation: "division of 'Time' datatype is not allowed"),
        #[cfg(feature = "dtype-date")]
        (Date, _) => polars_bail!(InvalidOperation: "division of 'Date' datatype is not allowed"),
        // we don't know what to do here, best return the dtype
        (dt, _) => dt.clone(),
    };
    Ok(out_type)
}