Path: blob/main/crates/polars-plan/src/plans/conversion/dsl_to_ir/functions.rs
8503 views
use arrow::legacy::error::PolarsResult;1use polars_utils::arena::Node;2use polars_utils::format_pl_smallstr;3use polars_utils::option::OptionTry;45use super::expr_to_ir::ExprToIRContext;6use super::*;7use crate::constants::get_literal_name;8use crate::dsl::{Expr, FunctionExpr};9use crate::plans::conversion::dsl_to_ir::expr_to_ir::to_expr_irs;10use crate::plans::{AExpr, IRFunctionExpr};1112pub(super) fn convert_functions(13input: Vec<Expr>,14function: FunctionExpr,15ctx: &mut ExprToIRContext,16) -> PolarsResult<(Node, PlSmallStr)> {17use {FunctionExpr as F, IRFunctionExpr as I};1819// Converts inputs20let input_is_empty = input.is_empty();21let e = to_expr_irs(input, ctx)?;22let mut set_elementwise = false;2324// Return before converting inputs25let ir_function = match function {26#[cfg(feature = "dtype-array")]27F::ArrayExpr(array_function) => {28use {ArrayFunction as A, IRArrayFunction as IA};29I::ArrayExpr(match array_function {30A::Length => IA::Length,31A::Min => IA::Min,32A::Max => IA::Max,33A::Sum => IA::Sum,34A::ToList => IA::ToList,35A::Unique(stable) => IA::Unique(stable),36A::NUnique => IA::NUnique,37A::Std(v) => IA::Std(v),38A::Var(v) => IA::Var(v),39A::Mean => IA::Mean,40A::Median => IA::Median,41#[cfg(feature = "array_any_all")]42A::Any => IA::Any,43#[cfg(feature = "array_any_all")]44A::All => IA::All,45A::Sort(sort_options) => IA::Sort(sort_options),46A::Reverse => IA::Reverse,47A::ArgMin => IA::ArgMin,48A::ArgMax => IA::ArgMax,49A::Get(v) => IA::Get(v),50A::Join(v) => IA::Join(v),51#[cfg(feature = "is_in")]52A::Contains { nulls_equal } => IA::Contains { nulls_equal },53#[cfg(feature = "array_count")]54A::CountMatches => IA::CountMatches,55A::Shift => IA::Shift,56A::Explode(options) => IA::Explode(options),57A::Concat => IA::Concat,58A::Slice(offset, length) => IA::Slice(offset, length),59#[cfg(feature = "array_to_struct")]60A::ToStruct(ng) => IA::ToStruct(ng),61})62},63F::BinaryExpr(binary_function) => {64use {BinaryFunction as B, IRBinaryFunction as IB};65I::BinaryExpr(match binary_function {66B::Contains => IB::Contains,67B::StartsWith => IB::StartsWith,68B::EndsWith => IB::EndsWith,69#[cfg(feature = "binary_encoding")]70B::HexDecode(v) => IB::HexDecode(v),71#[cfg(feature = "binary_encoding")]72B::HexEncode => IB::HexEncode,73#[cfg(feature = "binary_encoding")]74B::Base64Decode(v) => IB::Base64Decode(v),75#[cfg(feature = "binary_encoding")]76B::Base64Encode => IB::Base64Encode,77B::Size => IB::Size,78#[cfg(feature = "binary_encoding")]79B::Reinterpret(dtype_expr, v) => {80let dtype = dtype_expr.into_datatype(ctx.schema)?;81let can_reinterpret_to =82|dt: &DataType| dt.is_primitive_numeric() || dt.is_temporal();83polars_ensure!(84can_reinterpret_to(&dtype) || (85dtype.is_array() && dtype.inner_dtype().map(can_reinterpret_to) == Some(true)86),87InvalidOperation:88"cannot reinterpret binary to dtype {:?}. Only numeric or temporal dtype, or Arrays of these, are supported. Hint: To reinterpret to a nested Array, first reinterpret to a linear Array, and then use reshape",89dtype90);91IB::Reinterpret(dtype, v)92},93B::Slice => IB::Slice,94B::Head => IB::Head,95B::Tail => IB::Tail,96B::Get(null_on_oob) => IB::Get(null_on_oob),97})98},99#[cfg(feature = "dtype-categorical")]100F::Categorical(categorical_function) => {101use {CategoricalFunction as C, IRCategoricalFunction as IC};102I::Categorical(match categorical_function {103C::GetCategories => IC::GetCategories,104#[cfg(feature = "strings")]105C::LenBytes => IC::LenBytes,106#[cfg(feature = "strings")]107C::LenChars => IC::LenChars,108#[cfg(feature = "strings")]109C::StartsWith(v) => IC::StartsWith(v),110#[cfg(feature = "strings")]111C::EndsWith(v) => IC::EndsWith(v),112#[cfg(feature = "strings")]113C::Slice(s, e) => IC::Slice(s, e),114})115},116#[cfg(feature = "dtype-extension")]117F::Extension(extension_function) => {118use {ExtensionFunction as E, IRExtensionFunction as IE};119I::Extension(match extension_function {120E::To(dtype) => {121let concrete_dtype = dtype.into_datatype(ctx.schema)?;122polars_ensure!(matches!(concrete_dtype, DataType::Extension(_, _)),123InvalidOperation: "ext.to() requires an Extension dtype, got {concrete_dtype:?}"124);125IE::To(concrete_dtype)126},127E::Storage => IE::Storage,128})129},130F::ListExpr(list_function) => {131use {IRListFunction as IL, ListFunction as L};132I::ListExpr(match list_function {133L::Concat => IL::Concat,134#[cfg(feature = "is_in")]135L::Contains { nulls_equal } => IL::Contains { nulls_equal },136#[cfg(feature = "list_drop_nulls")]137L::DropNulls => IL::DropNulls,138#[cfg(feature = "list_sample")]139L::Sample {140is_fraction,141with_replacement,142shuffle,143seed,144} => IL::Sample {145is_fraction,146with_replacement,147shuffle,148seed,149},150L::Slice => IL::Slice,151L::Shift => IL::Shift,152L::Get(v) => IL::Get(v),153#[cfg(feature = "list_gather")]154L::Gather(v) => IL::Gather(v),155#[cfg(feature = "list_gather")]156L::GatherEvery => IL::GatherEvery,157#[cfg(feature = "list_count")]158L::CountMatches => IL::CountMatches,159L::Sum => IL::Sum,160L::Length => IL::Length,161L::Max => IL::Max,162L::Min => IL::Min,163L::Mean => IL::Mean,164L::Median => IL::Median,165L::Std(v) => IL::Std(v),166L::Var(v) => IL::Var(v),167L::ArgMin => IL::ArgMin,168L::ArgMax => IL::ArgMax,169#[cfg(feature = "diff")]170L::Diff { n, null_behavior } => IL::Diff { n, null_behavior },171L::Sort(sort_options) => IL::Sort(sort_options),172L::Reverse => IL::Reverse,173L::Unique(v) => IL::Unique(v),174L::NUnique => IL::NUnique,175#[cfg(feature = "list_sets")]176L::SetOperation(set_operation) => IL::SetOperation(set_operation),177#[cfg(feature = "list_any_all")]178L::Any => IL::Any,179#[cfg(feature = "list_any_all")]180L::All => IL::All,181L::Join(v) => IL::Join(v),182#[cfg(feature = "dtype-array")]183L::ToArray(v) => IL::ToArray(v),184#[cfg(feature = "list_to_struct")]185L::ToStruct(list_to_struct_args) => IL::ToStruct(list_to_struct_args),186})187},188#[cfg(feature = "strings")]189F::StringExpr(string_function) => {190use {IRStringFunction as IS, StringFunction as S};191I::StringExpr(match string_function {192S::Format { format, insertions } => {193if input_is_empty {194polars_ensure!(195insertions.is_empty(),196ComputeError: "StringFormat didn't get any inputs, format: \"{}\"",197format198);199200let out = ctx201.arena202.add(AExpr::Literal(LiteralValue::Scalar(Scalar::from(format))));203204return Ok((out, get_literal_name()));205} else {206IS::Format { format, insertions }207}208},209#[cfg(feature = "concat_str")]210S::ConcatHorizontal {211delimiter,212ignore_nulls,213} => IS::ConcatHorizontal {214delimiter,215ignore_nulls,216},217#[cfg(feature = "concat_str")]218S::ConcatVertical {219delimiter,220ignore_nulls,221} => IS::ConcatVertical {222delimiter,223ignore_nulls,224},225#[cfg(feature = "regex")]226S::Contains { literal, strict } => IS::Contains { literal, strict },227S::CountMatches(v) => IS::CountMatches(v),228S::EndsWith => IS::EndsWith,229S::Extract(v) => IS::Extract(v),230S::ExtractAll => IS::ExtractAll,231#[cfg(feature = "extract_groups")]232S::ExtractGroups { dtype, pat } => IS::ExtractGroups { dtype, pat },233#[cfg(feature = "regex")]234S::Find { literal, strict } => IS::Find { literal, strict },235#[cfg(feature = "string_to_integer")]236S::ToInteger { dtype, strict } => IS::ToInteger { dtype, strict },237S::LenBytes => IS::LenBytes,238S::LenChars => IS::LenChars,239S::Lowercase => IS::Lowercase,240#[cfg(feature = "extract_jsonpath")]241S::JsonDecode(dtype) => IS::JsonDecode(dtype.into_datatype(ctx.schema)?),242#[cfg(feature = "extract_jsonpath")]243S::JsonPathMatch => IS::JsonPathMatch,244#[cfg(feature = "regex")]245S::Replace { n, literal } => IS::Replace { n, literal },246#[cfg(feature = "string_normalize")]247S::Normalize { form } => IS::Normalize { form },248#[cfg(feature = "string_reverse")]249S::Reverse => IS::Reverse,250#[cfg(feature = "string_pad")]251S::PadStart { fill_char } => IS::PadStart { fill_char },252#[cfg(feature = "string_pad")]253S::PadEnd { fill_char } => IS::PadEnd { fill_char },254S::Slice => IS::Slice,255S::Head => IS::Head,256S::Tail => IS::Tail,257#[cfg(feature = "string_encoding")]258S::HexEncode => IS::HexEncode,259#[cfg(feature = "binary_encoding")]260S::HexDecode(v) => IS::HexDecode(v),261#[cfg(feature = "string_encoding")]262S::Base64Encode => IS::Base64Encode,263#[cfg(feature = "binary_encoding")]264S::Base64Decode(v) => IS::Base64Decode(v),265S::StartsWith => IS::StartsWith,266S::StripChars => IS::StripChars,267S::StripCharsStart => IS::StripCharsStart,268S::StripCharsEnd => IS::StripCharsEnd,269S::StripPrefix => IS::StripPrefix,270S::StripSuffix => IS::StripSuffix,271#[cfg(feature = "dtype-struct")]272S::SplitExact { n, inclusive } => IS::SplitExact { n, inclusive },273#[cfg(feature = "dtype-struct")]274S::SplitN(v) => IS::SplitN(v),275#[cfg(feature = "regex")]276S::SplitRegex { inclusive, strict } => IS::SplitRegex { inclusive, strict },277#[cfg(feature = "temporal")]278S::Strptime(data_type, strptime_options) => {279let is_column_independent = is_column_independent_aexpr(e[0].node(), ctx.arena);280set_elementwise = is_column_independent;281let dtype = data_type.into_datatype(ctx.schema)?;282polars_ensure!(283matches!(dtype,284DataType::Date |285DataType::Datetime(_, _) |286DataType::Time287),288InvalidOperation: "`strptime` expects a `date`, `datetime` or `time` got {dtype}"289);290IS::Strptime(dtype, strptime_options)291},292S::Split(v) => IS::Split(v),293#[cfg(feature = "dtype-decimal")]294S::ToDecimal { scale } => IS::ToDecimal { scale },295#[cfg(feature = "nightly")]296S::Titlecase => IS::Titlecase,297S::Uppercase => IS::Uppercase,298#[cfg(feature = "string_pad")]299S::ZFill => IS::ZFill,300#[cfg(feature = "find_many")]301S::ContainsAny {302ascii_case_insensitive,303} => IS::ContainsAny {304ascii_case_insensitive,305},306#[cfg(feature = "find_many")]307S::ReplaceMany {308ascii_case_insensitive,309leftmost,310} => IS::ReplaceMany {311ascii_case_insensitive,312leftmost,313},314#[cfg(feature = "find_many")]315S::ExtractMany {316ascii_case_insensitive,317overlapping,318leftmost,319} => IS::ExtractMany {320ascii_case_insensitive,321overlapping,322leftmost,323},324#[cfg(feature = "find_many")]325S::FindMany {326ascii_case_insensitive,327overlapping,328leftmost,329} => IS::FindMany {330ascii_case_insensitive,331overlapping,332leftmost,333},334#[cfg(feature = "regex")]335S::EscapeRegex => IS::EscapeRegex,336})337},338#[cfg(feature = "dtype-struct")]339F::StructExpr(struct_function) => {340use {IRStructFunction as IS, StructFunction as S};341I::StructExpr(match struct_function {342S::FieldByName(pl_small_str) => IS::FieldByName(pl_small_str),343S::RenameFields(pl_small_strs) => IS::RenameFields(pl_small_strs),344S::PrefixFields(pl_small_str) => IS::PrefixFields(pl_small_str),345S::SuffixFields(pl_small_str) => IS::SuffixFields(pl_small_str),346S::SelectFields(_) => unreachable!("handled by expression expansion"),347#[cfg(feature = "json")]348S::JsonEncode => IS::JsonEncode,349S::MapFieldNames(f) => IS::MapFieldNames(f),350})351},352#[cfg(feature = "temporal")]353F::TemporalExpr(temporal_function) => {354use {IRTemporalFunction as IT, TemporalFunction as T};355I::TemporalExpr(match temporal_function {356T::Millennium => IT::Millennium,357T::Century => IT::Century,358T::Year => IT::Year,359T::IsLeapYear => IT::IsLeapYear,360T::IsoYear => IT::IsoYear,361T::Quarter => IT::Quarter,362T::Month => IT::Month,363T::DaysInMonth => IT::DaysInMonth,364T::Week => IT::Week,365T::WeekDay => IT::WeekDay,366T::Day => IT::Day,367T::OrdinalDay => IT::OrdinalDay,368T::Time => IT::Time,369T::Date => IT::Date,370T::Datetime => IT::Datetime,371#[cfg(feature = "dtype-duration")]372T::Duration(time_unit) => IT::Duration(time_unit),373T::Hour => IT::Hour,374T::Minute => IT::Minute,375T::Second => IT::Second,376T::Millisecond => IT::Millisecond,377T::Microsecond => IT::Microsecond,378T::Nanosecond => IT::Nanosecond,379#[cfg(feature = "dtype-duration")]380T::TotalDays { fractional } => IT::TotalDays { fractional },381#[cfg(feature = "dtype-duration")]382T::TotalHours { fractional } => IT::TotalHours { fractional },383#[cfg(feature = "dtype-duration")]384T::TotalMinutes { fractional } => IT::TotalMinutes { fractional },385#[cfg(feature = "dtype-duration")]386T::TotalSeconds { fractional } => IT::TotalSeconds { fractional },387#[cfg(feature = "dtype-duration")]388T::TotalMilliseconds { fractional } => IT::TotalMilliseconds { fractional },389#[cfg(feature = "dtype-duration")]390T::TotalMicroseconds { fractional } => IT::TotalMicroseconds { fractional },391#[cfg(feature = "dtype-duration")]392T::TotalNanoseconds { fractional } => IT::TotalNanoseconds { fractional },393T::ToString(v) => IT::ToString(v),394T::CastTimeUnit(time_unit) => IT::CastTimeUnit(time_unit),395T::WithTimeUnit(time_unit) => IT::WithTimeUnit(time_unit),396#[cfg(feature = "timezones")]397T::ConvertTimeZone(time_zone) => IT::ConvertTimeZone(time_zone),398T::TimeStamp(time_unit) => IT::TimeStamp(time_unit),399T::Truncate => IT::Truncate,400#[cfg(feature = "offset_by")]401T::OffsetBy => IT::OffsetBy,402#[cfg(feature = "month_start")]403T::MonthStart => IT::MonthStart,404#[cfg(feature = "month_end")]405T::MonthEnd => IT::MonthEnd,406#[cfg(feature = "timezones")]407T::BaseUtcOffset => IT::BaseUtcOffset,408#[cfg(feature = "timezones")]409T::DSTOffset => IT::DSTOffset,410T::Round => IT::Round,411T::Replace => IT::Replace,412#[cfg(feature = "timezones")]413T::ReplaceTimeZone(time_zone, non_existent) => {414IT::ReplaceTimeZone(time_zone, non_existent)415},416T::Combine(time_unit) => IT::Combine(time_unit),417T::DatetimeFunction {418time_unit,419time_zone,420} => IT::DatetimeFunction {421time_unit,422time_zone,423},424})425},426#[cfg(feature = "bitwise")]427F::Bitwise(bitwise_function) => I::Bitwise(match bitwise_function {428BitwiseFunction::CountOnes => IRBitwiseFunction::CountOnes,429BitwiseFunction::CountZeros => IRBitwiseFunction::CountZeros,430BitwiseFunction::LeadingOnes => IRBitwiseFunction::LeadingOnes,431BitwiseFunction::LeadingZeros => IRBitwiseFunction::LeadingZeros,432BitwiseFunction::TrailingOnes => IRBitwiseFunction::TrailingOnes,433BitwiseFunction::TrailingZeros => IRBitwiseFunction::TrailingZeros,434BitwiseFunction::And => IRBitwiseFunction::And,435BitwiseFunction::Or => IRBitwiseFunction::Or,436BitwiseFunction::Xor => IRBitwiseFunction::Xor,437}),438F::Boolean(boolean_function) => {439use {BooleanFunction as B, IRBooleanFunction as IB};440I::Boolean(match boolean_function {441B::Any { ignore_nulls } => IB::Any { ignore_nulls },442B::All { ignore_nulls } => IB::All { ignore_nulls },443B::IsNull => IB::IsNull,444B::IsNotNull => IB::IsNotNull,445B::IsFinite => IB::IsFinite,446B::IsInfinite => IB::IsInfinite,447B::IsNan => IB::IsNan,448B::IsNotNan => IB::IsNotNan,449#[cfg(feature = "is_first_distinct")]450B::IsFirstDistinct => IB::IsFirstDistinct,451#[cfg(feature = "is_last_distinct")]452B::IsLastDistinct => IB::IsLastDistinct,453#[cfg(feature = "is_unique")]454B::IsUnique => IB::IsUnique,455#[cfg(feature = "is_unique")]456B::IsDuplicated => IB::IsDuplicated,457#[cfg(feature = "is_between")]458B::IsBetween { closed } => IB::IsBetween { closed },459#[cfg(feature = "is_in")]460B::IsIn { nulls_equal } => IB::IsIn { nulls_equal },461#[cfg(feature = "is_close")]462B::IsClose {463abs_tol,464rel_tol,465nans_equal,466} => IB::IsClose {467abs_tol,468rel_tol,469nans_equal,470},471B::AllHorizontal => {472let Some(fst) = e.first() else {473return Ok((474ctx.arena.add(AExpr::Literal(Scalar::from(true).into())),475format_pl_smallstr!("{}", IB::AllHorizontal),476));477};478479if e.len() == 1 {480return Ok((481AExprBuilder::new_from_node(fst.node())482.cast(DataType::Boolean, ctx.arena)483.node(),484fst.output_name().clone(),485));486}487488// Convert to binary expression as the optimizer understands those.489// Don't exceed 128 expressions as we might stackoverflow.490if e.len() < 128 {491let mut r = AExprBuilder::new_from_node(fst.node());492for expr in &e[1..] {493r = r.logical_and(expr.node(), ctx.arena);494}495return Ok((r.node(), fst.output_name().clone()));496}497498IB::AllHorizontal499},500B::AnyHorizontal => {501// This can be created by col(*).is_null() on empty dataframes.502let Some(fst) = e.first() else {503return Ok((504ctx.arena.add(AExpr::Literal(Scalar::from(false).into())),505format_pl_smallstr!("{}", IB::AnyHorizontal),506));507};508509if e.len() == 1 {510return Ok((511AExprBuilder::new_from_node(fst.node())512.cast(DataType::Boolean, ctx.arena)513.node(),514fst.output_name().clone(),515));516}517518// Convert to binary expression as the optimizer understands those.519// Don't exceed 128 expressions as we might stackoverflow.520if e.len() < 128 {521let mut r = AExprBuilder::new_from_node(fst.node());522for expr in &e[1..] {523r = r.logical_or(expr.node(), ctx.arena);524}525return Ok((r.node(), fst.output_name().clone()));526}527528IB::AnyHorizontal529},530B::Not => IB::Not,531})532},533#[cfg(feature = "business")]534F::Business(business_function) => I::Business(match business_function {535BusinessFunction::BusinessDayCount {536week_mask,537holidays,538} => IRBusinessFunction::BusinessDayCount {539week_mask,540holidays,541},542BusinessFunction::AddBusinessDay {543week_mask,544holidays,545roll,546} => IRBusinessFunction::AddBusinessDay {547week_mask,548holidays,549roll,550},551BusinessFunction::IsBusinessDay {552week_mask,553holidays,554} => IRBusinessFunction::IsBusinessDay {555week_mask,556holidays,557},558}),559#[cfg(feature = "abs")]560F::Abs => I::Abs,561F::Negate => I::Negate,562#[cfg(feature = "hist")]563F::Hist {564bin_count,565include_category,566include_breakpoint,567} => I::Hist {568bin_count,569include_category,570include_breakpoint,571},572F::NullCount => I::NullCount,573F::Pow(pow_function) => I::Pow(match pow_function {574PowFunction::Generic => IRPowFunction::Generic,575PowFunction::Sqrt => IRPowFunction::Sqrt,576PowFunction::Cbrt => IRPowFunction::Cbrt,577}),578#[cfg(feature = "row_hash")]579F::Hash(s0, s1, s2, s3) => I::Hash(s0, s1, s2, s3),580#[cfg(feature = "arg_where")]581F::ArgWhere => I::ArgWhere,582#[cfg(feature = "index_of")]583F::IndexOf => I::IndexOf,584#[cfg(feature = "search_sorted")]585F::SearchSorted { side, descending } => I::SearchSorted { side, descending },586#[cfg(feature = "range")]587F::Range(range_function) => I::Range(match range_function {588RangeFunction::IntRange { step, dtype } => {589let dtype = dtype.into_datatype(ctx.schema)?;590polars_ensure!(e[0].is_scalar(ctx.arena), ShapeMismatch: "non-scalar start passed to `int_range`");591polars_ensure!(e[1].is_scalar(ctx.arena), ShapeMismatch: "non-scalar stop passed to `int_range`");592polars_ensure!(dtype.is_integer(), SchemaMismatch: "non-integer `dtype` passed to `int_range`: '{dtype}'");593IRRangeFunction::IntRange { step, dtype }594},595RangeFunction::IntRanges { dtype } => {596let dtype = dtype.into_datatype(ctx.schema)?;597polars_ensure!(dtype.is_integer(), SchemaMismatch: "non-integer `dtype` passed to `int_ranges`: '{dtype}'");598IRRangeFunction::IntRanges { dtype }599},600RangeFunction::LinearSpace { closed } => {601polars_ensure!(e[0].is_scalar(ctx.arena), ShapeMismatch: "non-scalar start passed to `linear_space`");602polars_ensure!(e[1].is_scalar(ctx.arena), ShapeMismatch: "non-scalar end passed to `linear_space`");603polars_ensure!(e[2].is_scalar(ctx.arena), ShapeMismatch: "non-scalar num_samples passed to `linear_space`");604IRRangeFunction::LinearSpace { closed }605},606RangeFunction::LinearSpaces {607closed,608array_width,609} => IRRangeFunction::LinearSpaces {610closed,611array_width,612},613#[cfg(all(feature = "range", feature = "dtype-date"))]614RangeFunction::DateRange {615interval,616closed,617arg_type,618} => {619use DateRangeArgs::*;620let arg_names = match arg_type {621StartEndSamples => vec!["start", "end", "num_samples"],622StartEndInterval => vec!["start", "end"],623StartIntervalSamples => vec!["start", "num_samples"],624EndIntervalSamples => vec!["end", "num_samples"],625};626for (idx, &name) in arg_names.iter().enumerate() {627polars_ensure!(e[idx].is_scalar(ctx.arena), ShapeMismatch: "non-scalar {name} passed to `date_range`");628}629IRRangeFunction::DateRange {630interval,631closed,632arg_type,633}634},635#[cfg(all(feature = "range", feature = "dtype-date"))]636RangeFunction::DateRanges {637interval,638closed,639arg_type,640} => IRRangeFunction::DateRanges {641interval,642closed,643arg_type,644},645#[cfg(all(feature = "range", feature = "dtype-datetime"))]646RangeFunction::DatetimeRange {647interval,648closed,649time_unit,650time_zone,651arg_type,652} => {653use DateRangeArgs::*;654let arg_names = match arg_type {655StartEndSamples => vec!["start", "end", "num_samples"],656StartEndInterval => vec!["start", "end"],657StartIntervalSamples => vec!["start", "num_samples"],658EndIntervalSamples => vec!["end", "num_samples"],659};660for (idx, &name) in arg_names.iter().enumerate() {661polars_ensure!(e[idx].is_scalar(ctx.arena), ShapeMismatch: "non-scalar {name} passed to `datetime_range`");662}663IRRangeFunction::DatetimeRange {664interval,665closed,666time_unit,667time_zone,668arg_type,669}670},671#[cfg(all(feature = "range", feature = "dtype-datetime"))]672RangeFunction::DatetimeRanges {673interval,674closed,675time_unit,676time_zone,677arg_type,678} => IRRangeFunction::DatetimeRanges {679interval,680closed,681time_unit,682time_zone,683arg_type,684},685#[cfg(all(feature = "range", feature = "dtype-time"))]686RangeFunction::TimeRange { interval, closed } => {687polars_ensure!(e[0].is_scalar(ctx.arena), ShapeMismatch: "non-scalar start passed to `time_range`");688polars_ensure!(e[1].is_scalar(ctx.arena), ShapeMismatch: "non-scalar end passed to `time_range`");689IRRangeFunction::TimeRange { interval, closed }690},691#[cfg(all(feature = "range", feature = "dtype-time"))]692RangeFunction::TimeRanges { interval, closed } => {693IRRangeFunction::TimeRanges { interval, closed }694},695}),696#[cfg(feature = "trigonometry")]697F::Trigonometry(trigonometric_function) => {698use {IRTrigonometricFunction as IT, TrigonometricFunction as T};699I::Trigonometry(match trigonometric_function {700T::Cos => IT::Cos,701T::Cot => IT::Cot,702T::Sin => IT::Sin,703T::Tan => IT::Tan,704T::ArcCos => IT::ArcCos,705T::ArcSin => IT::ArcSin,706T::ArcTan => IT::ArcTan,707T::Cosh => IT::Cosh,708T::Sinh => IT::Sinh,709T::Tanh => IT::Tanh,710T::ArcCosh => IT::ArcCosh,711T::ArcSinh => IT::ArcSinh,712T::ArcTanh => IT::ArcTanh,713T::Degrees => IT::Degrees,714T::Radians => IT::Radians,715})716},717#[cfg(feature = "trigonometry")]718F::Atan2 => I::Atan2,719#[cfg(feature = "sign")]720F::Sign => I::Sign,721F::FillNull => I::FillNull,722F::FillNullWithStrategy(fill_null_strategy) => I::FillNullWithStrategy(fill_null_strategy),723#[cfg(feature = "rolling_window")]724F::RollingExpr { function, options } => {725use RollingFunction as R;726use aexpr::IRRollingFunction as IR;727728I::RollingExpr {729function: match function {730R::Min => IR::Min,731R::Max => IR::Max,732R::Mean => IR::Mean,733R::Sum => IR::Sum,734R::Quantile => IR::Quantile,735R::Var => IR::Var,736R::Std => IR::Std,737R::Rank => IR::Rank,738#[cfg(feature = "moment")]739R::Skew => IR::Skew,740#[cfg(feature = "moment")]741R::Kurtosis => IR::Kurtosis,742#[cfg(feature = "cov")]743R::CorrCov {744corr_cov_options,745is_corr,746} => IR::CorrCov {747corr_cov_options,748is_corr,749},750R::Map(f) => IR::Map(f),751},752options,753}754},755#[cfg(feature = "rolling_window_by")]756F::RollingExprBy {757function_by,758options,759} => {760use RollingFunctionBy as R;761use aexpr::IRRollingFunctionBy as IR;762763I::RollingExprBy {764function_by: match function_by {765R::MinBy => IR::MinBy,766R::MaxBy => IR::MaxBy,767R::MeanBy => IR::MeanBy,768R::SumBy => IR::SumBy,769R::QuantileBy => IR::QuantileBy,770R::VarBy => IR::VarBy,771R::StdBy => IR::StdBy,772R::RankBy => IR::RankBy,773},774options,775}776},777F::Rechunk => I::Rechunk,778F::Append { upcast } => I::Append { upcast },779F::ShiftAndFill => {780polars_ensure!(&e[1].is_scalar(ctx.arena), ShapeMismatch: "'n' must be a scalar value");781polars_ensure!(&e[2].is_scalar(ctx.arena), ShapeMismatch: "'fill_value' must be a scalar value");782I::ShiftAndFill783},784F::Shift => {785polars_ensure!(&e[1].is_scalar(ctx.arena), ShapeMismatch: "'n' must be a scalar value");786I::Shift787},788F::DropNans => I::DropNans,789F::DropNulls => I::DropNulls,790#[cfg(feature = "mode")]791F::Mode { maintain_order } => I::Mode { maintain_order },792#[cfg(feature = "moment")]793F::Skew(v) => I::Skew(v),794#[cfg(feature = "moment")]795F::Kurtosis(l, r) => I::Kurtosis(l, r),796#[cfg(feature = "dtype-array")]797F::Reshape(reshape_dimensions) => I::Reshape(reshape_dimensions),798#[cfg(feature = "repeat_by")]799F::RepeatBy => I::RepeatBy,800F::ArgUnique => I::ArgUnique,801F::ArgMin => I::ArgMin,802F::ArgMax => I::ArgMax,803F::ArgSort {804descending,805nulls_last,806} => I::ArgSort {807descending,808nulls_last,809},810F::MinBy => I::MinBy,811F::MaxBy => I::MaxBy,812F::Product => I::Product,813#[cfg(feature = "rank")]814F::Rank { options, seed } => I::Rank { options, seed },815F::Repeat => {816polars_ensure!(&e[0].is_scalar(ctx.arena), ShapeMismatch: "'value' must be a scalar value");817polars_ensure!(&e[1].is_scalar(ctx.arena), ShapeMismatch: "'n' must be a scalar value");818I::Repeat819},820#[cfg(feature = "round_series")]821F::Clip { has_min, has_max } => I::Clip { has_min, has_max },822#[cfg(feature = "dtype-struct")]823F::AsStruct => I::AsStruct,824#[cfg(feature = "top_k")]825F::TopK { descending } => I::TopK { descending },826#[cfg(feature = "top_k")]827F::TopKBy { descending } => I::TopKBy { descending },828#[cfg(feature = "cum_agg")]829F::CumCount { reverse } => I::CumCount { reverse },830#[cfg(feature = "cum_agg")]831F::CumSum { reverse } => I::CumSum { reverse },832#[cfg(feature = "cum_agg")]833F::CumProd { reverse } => I::CumProd { reverse },834#[cfg(feature = "cum_agg")]835F::CumMin { reverse } => I::CumMin { reverse },836#[cfg(feature = "cum_agg")]837F::CumMax { reverse } => I::CumMax { reverse },838F::Reverse => I::Reverse,839#[cfg(feature = "dtype-struct")]840F::ValueCounts {841sort,842parallel,843name,844normalize,845} => I::ValueCounts {846sort,847parallel,848name,849normalize,850},851#[cfg(feature = "unique_counts")]852F::UniqueCounts => I::UniqueCounts,853#[cfg(feature = "approx_unique")]854F::ApproxNUnique => I::ApproxNUnique,855F::Coalesce => I::Coalesce,856#[cfg(feature = "diff")]857F::Diff(n) => {858polars_ensure!(&e[1].is_scalar(ctx.arena), ShapeMismatch: "'n' must be a scalar value");859I::Diff(n)860},861#[cfg(feature = "pct_change")]862F::PctChange => I::PctChange,863#[cfg(feature = "interpolate")]864F::Interpolate(interpolation_method) => I::Interpolate(interpolation_method),865#[cfg(feature = "interpolate_by")]866F::InterpolateBy => I::InterpolateBy,867#[cfg(feature = "log")]868F::Entropy { base, normalize } => I::Entropy { base, normalize },869#[cfg(feature = "log")]870F::Log => I::Log,871#[cfg(feature = "log")]872F::Log1p => I::Log1p,873#[cfg(feature = "log")]874F::Exp => I::Exp,875F::Unique(v) => I::Unique(v),876#[cfg(feature = "round_series")]877F::Round { decimals, mode } => I::Round { decimals, mode },878#[cfg(feature = "round_series")]879F::RoundSF { digits } => I::RoundSF { digits },880#[cfg(feature = "round_series")]881F::Truncate { decimals } => I::Truncate { decimals },882#[cfg(feature = "round_series")]883F::Floor => I::Floor,884#[cfg(feature = "round_series")]885F::Ceil => I::Ceil,886F::UpperBound => {887let field = e[0].field(ctx.schema, ctx.arena)?;888return Ok((889ctx.arena890.add(AExpr::Literal(field.dtype.to_physical().max()?.into())),891field.name,892));893},894F::LowerBound => {895let field = e[0].field(ctx.schema, ctx.arena)?;896return Ok((897ctx.arena898.add(AExpr::Literal(field.dtype.to_physical().min()?.into())),899field.name,900));901},902F::ConcatExpr(v) => I::ConcatExpr(v),903#[cfg(feature = "cov")]904F::Correlation { method } => {905use {CorrelationMethod as C, IRCorrelationMethod as IC};906I::Correlation {907method: match method {908C::Pearson => IC::Pearson,909#[cfg(all(feature = "rank", feature = "propagate_nans"))]910C::SpearmanRank(v) => IC::SpearmanRank(v),911C::Covariance(v) => IC::Covariance(v),912},913}914},915#[cfg(feature = "peaks")]916F::PeakMin => I::PeakMin,917#[cfg(feature = "peaks")]918F::PeakMax => I::PeakMax,919#[cfg(feature = "cutqcut")]920F::Cut {921breaks,922labels,923left_closed,924include_breaks,925} => I::Cut {926breaks,927labels,928left_closed,929include_breaks,930},931#[cfg(feature = "cutqcut")]932F::QCut {933probs,934labels,935left_closed,936allow_duplicates,937include_breaks,938} => I::QCut {939probs,940labels,941left_closed,942allow_duplicates,943include_breaks,944},945#[cfg(feature = "rle")]946F::RLE => I::RLE,947#[cfg(feature = "rle")]948F::RLEID => I::RLEID,949F::ToPhysical => I::ToPhysical,950#[cfg(feature = "random")]951F::Random { method, seed } => {952use {IRRandomMethod as IR, RandomMethod as R};953I::Random {954method: match method {955R::Shuffle => IR::Shuffle,956R::Sample {957is_fraction,958with_replacement,959shuffle,960} => IR::Sample {961is_fraction,962with_replacement,963shuffle,964},965},966seed,967}968},969F::SetSortedFlag(is_sorted) => I::SetSortedFlag(is_sorted),970#[cfg(feature = "ffi_plugin")]971F::FfiPlugin {972flags,973lib,974symbol,975kwargs,976} => I::FfiPlugin {977flags,978lib,979symbol,980kwargs,981},982983F::FoldHorizontal {984callback,985returns_scalar,986return_dtype,987} => I::FoldHorizontal {988callback,989returns_scalar,990return_dtype: return_dtype.try_map(|dtype| dtype.into_datatype(ctx.schema))?,991},992F::ReduceHorizontal {993callback,994returns_scalar,995return_dtype,996} => I::ReduceHorizontal {997callback,998returns_scalar,999return_dtype: return_dtype.try_map(|dtype| dtype.into_datatype(ctx.schema))?,1000},1001#[cfg(feature = "dtype-struct")]1002F::CumReduceHorizontal {1003callback,1004returns_scalar,1005return_dtype,1006} => I::CumReduceHorizontal {1007callback,1008returns_scalar,1009return_dtype: return_dtype.try_map(|dtype| dtype.into_datatype(ctx.schema))?,1010},1011#[cfg(feature = "dtype-struct")]1012F::CumFoldHorizontal {1013callback,1014returns_scalar,1015return_dtype,1016include_init,1017} => I::CumFoldHorizontal {1018callback,1019returns_scalar,1020return_dtype: return_dtype.try_map(|dtype| dtype.into_datatype(ctx.schema))?,1021include_init,1022},10231024F::MaxHorizontal => I::MaxHorizontal,1025F::MinHorizontal => I::MinHorizontal,1026F::SumHorizontal { ignore_nulls } => I::SumHorizontal { ignore_nulls },1027F::MeanHorizontal { ignore_nulls } => I::MeanHorizontal { ignore_nulls },1028#[cfg(feature = "ewma")]1029F::EwmMean { options } => I::EwmMean { options },1030#[cfg(feature = "ewma_by")]1031F::EwmMeanBy { half_life } => I::EwmMeanBy { half_life },1032#[cfg(feature = "ewma")]1033F::EwmStd { options } => I::EwmStd { options },1034#[cfg(feature = "ewma")]1035F::EwmVar { options } => I::EwmVar { options },1036#[cfg(feature = "replace")]1037F::Replace => I::Replace,1038#[cfg(feature = "replace")]1039F::ReplaceStrict { return_dtype } => I::ReplaceStrict {1040return_dtype: match return_dtype {1041Some(dtype) => Some(dtype.into_datatype(ctx.schema)?),1042None => None,1043},1044},1045F::GatherEvery { n, offset } => I::GatherEvery { n, offset },1046#[cfg(feature = "reinterpret")]1047F::Reinterpret(v) => I::Reinterpret(v),1048F::ExtendConstant => {1049polars_ensure!(&e[1].is_scalar(ctx.arena), ShapeMismatch: "'value' must be a scalar value");1050polars_ensure!(&e[2].is_scalar(ctx.arena), ShapeMismatch: "'n' must be a scalar value");1051I::ExtendConstant1052},10531054F::RowEncode(v) => {1055let dts = e1056.iter()1057.map(|e| Ok(e.dtype(ctx.schema, ctx.arena)?.clone()))1058.collect::<PolarsResult<Vec<_>>>()?;1059I::RowEncode(dts, v)1060},1061#[cfg(feature = "dtype-struct")]1062F::RowDecode(fs, v) => I::RowDecode(1063fs.into_iter()1064.map(|(name, dt_expr)| Ok(Field::new(name, dt_expr.into_datatype(ctx.schema)?)))1065.collect::<PolarsResult<Vec<_>>>()?,1066v,1067),1068};10691070let mut options = ir_function.function_options();1071if set_elementwise {1072options.set_elementwise();1073}10741075// Handles special case functions like `struct.field`.1076let output_name = match ir_function.output_name().and_then(|v| v.into_inner()) {1077Some(name) => name,1078None if e.is_empty() => format_pl_smallstr!("{}", &ir_function),1079None => e[0].output_name().clone(),1080};10811082let ae_function = AExpr::Function {1083input: e,1084function: ir_function,1085options,1086};1087Ok((ctx.arena.add(ae_function), output_name))1088}108910901091