Path: blob/main/crates/polars-plan/src/plans/conversion/dsl_to_ir/functions.rs
7889 views
use arrow::legacy::error::PolarsResult;1use polars_utils::arena::Node;2use polars_utils::format_pl_smallstr;3use polars_utils::option::OptionTry;45use super::expr_to_ir::ExprToIRContext;6use super::*;7use crate::constants::get_literal_name;8use crate::dsl::{Expr, FunctionExpr};9use crate::plans::conversion::dsl_to_ir::expr_to_ir::{to_expr_ir, to_expr_irs};10use crate::plans::{AExpr, IRFunctionExpr};1112pub(super) fn convert_functions(13input: Vec<Expr>,14function: FunctionExpr,15ctx: &mut ExprToIRContext,16) -> PolarsResult<(Node, PlSmallStr)> {17use {FunctionExpr as F, IRFunctionExpr as I};1819#[cfg(feature = "dtype-struct")]20if matches!(21function,22FunctionExpr::StructExpr(StructFunction::WithFields)23) {24let mut input = input.into_iter();25let struct_input = to_expr_ir(input.next().unwrap(), ctx)?;26let dtype = struct_input.to_expr(ctx.arena).to_field(ctx.schema)?.dtype;27let DataType::Struct(fields) = &dtype else {28polars_bail!(op = "struct.with_fields", dtype);29};3031let struct_name = struct_input.output_name().clone();32let struct_node = struct_input.node();33let struct_schema = Schema::from_iter(fields.iter().cloned());3435let mut e = Vec::with_capacity(input.len());36e.push(struct_input);3738let prev = ctx.with_fields.replace((struct_node, struct_schema));39for i in input {40e.push(to_expr_ir(i, ctx)?);41}42ctx.with_fields = prev;4344let function = IRFunctionExpr::StructExpr(IRStructFunction::WithFields);45let options = function.function_options();46let out = ctx.arena.add(AExpr::Function {47input: e,48function,49options,50});5152return Ok((out, struct_name));53}5455let input_is_empty = input.is_empty();5657// Converts inputs58let e = to_expr_irs(input, ctx)?;59let mut set_elementwise = false;6061// Return before converting inputs62let ir_function = match function {63#[cfg(feature = "dtype-array")]64F::ArrayExpr(array_function) => {65use {ArrayFunction as A, IRArrayFunction as IA};66I::ArrayExpr(match array_function {67A::Length => IA::Length,68A::Min => IA::Min,69A::Max => IA::Max,70A::Sum => IA::Sum,71A::ToList => IA::ToList,72A::Unique(stable) => IA::Unique(stable),73A::NUnique => IA::NUnique,74A::Std(v) => IA::Std(v),75A::Var(v) => IA::Var(v),76A::Mean => IA::Mean,77A::Median => IA::Median,78#[cfg(feature = "array_any_all")]79A::Any => IA::Any,80#[cfg(feature = "array_any_all")]81A::All => IA::All,82A::Sort(sort_options) => IA::Sort(sort_options),83A::Reverse => IA::Reverse,84A::ArgMin => IA::ArgMin,85A::ArgMax => IA::ArgMax,86A::Get(v) => IA::Get(v),87A::Join(v) => IA::Join(v),88#[cfg(feature = "is_in")]89A::Contains { nulls_equal } => IA::Contains { nulls_equal },90#[cfg(feature = "array_count")]91A::CountMatches => IA::CountMatches,92A::Shift => IA::Shift,93A::Explode(options) => IA::Explode(options),94A::Concat => IA::Concat,95A::Slice(offset, length) => IA::Slice(offset, length),96#[cfg(feature = "array_to_struct")]97A::ToStruct(ng) => IA::ToStruct(ng),98})99},100F::BinaryExpr(binary_function) => {101use {BinaryFunction as B, IRBinaryFunction as IB};102I::BinaryExpr(match binary_function {103B::Contains => IB::Contains,104B::StartsWith => IB::StartsWith,105B::EndsWith => IB::EndsWith,106#[cfg(feature = "binary_encoding")]107B::HexDecode(v) => IB::HexDecode(v),108#[cfg(feature = "binary_encoding")]109B::HexEncode => IB::HexEncode,110#[cfg(feature = "binary_encoding")]111B::Base64Decode(v) => IB::Base64Decode(v),112#[cfg(feature = "binary_encoding")]113B::Base64Encode => IB::Base64Encode,114B::Size => IB::Size,115#[cfg(feature = "binary_encoding")]116B::Reinterpret(dtype_expr, v) => {117let dtype = dtype_expr.into_datatype(ctx.schema)?;118let can_reinterpret_to =119|dt: &DataType| dt.is_primitive_numeric() || dt.is_temporal();120polars_ensure!(121can_reinterpret_to(&dtype) || (122dtype.is_array() && dtype.inner_dtype().map(can_reinterpret_to) == Some(true)123),124InvalidOperation:125"cannot reinterpret binary to dtype {:?}. Only numeric or temporal dtype, or Arrays of these, are supported. Hint: To reinterpret to a nested Array, first reinterpret to a linear Array, and then use reshape",126dtype127);128IB::Reinterpret(dtype, v)129},130B::Slice => IB::Slice,131B::Head => IB::Head,132B::Tail => IB::Tail,133})134},135#[cfg(feature = "dtype-categorical")]136F::Categorical(categorical_function) => {137use {CategoricalFunction as C, IRCategoricalFunction as IC};138I::Categorical(match categorical_function {139C::GetCategories => IC::GetCategories,140#[cfg(feature = "strings")]141C::LenBytes => IC::LenBytes,142#[cfg(feature = "strings")]143C::LenChars => IC::LenChars,144#[cfg(feature = "strings")]145C::StartsWith(v) => IC::StartsWith(v),146#[cfg(feature = "strings")]147C::EndsWith(v) => IC::EndsWith(v),148#[cfg(feature = "strings")]149C::Slice(s, e) => IC::Slice(s, e),150})151},152#[cfg(feature = "dtype-extension")]153F::Extension(extension_function) => {154use {ExtensionFunction as E, IRExtensionFunction as IE};155I::Extension(match extension_function {156E::To(dtype) => {157let concrete_dtype = dtype.into_datatype(ctx.schema)?;158polars_ensure!(matches!(concrete_dtype, DataType::Extension(_, _)),159InvalidOperation: "ext.to() requires an Extension dtype, got {concrete_dtype:?}"160);161IE::To(concrete_dtype)162},163E::Storage => IE::Storage,164})165},166F::ListExpr(list_function) => {167use {IRListFunction as IL, ListFunction as L};168I::ListExpr(match list_function {169L::Concat => IL::Concat,170#[cfg(feature = "is_in")]171L::Contains { nulls_equal } => IL::Contains { nulls_equal },172#[cfg(feature = "list_drop_nulls")]173L::DropNulls => IL::DropNulls,174#[cfg(feature = "list_sample")]175L::Sample {176is_fraction,177with_replacement,178shuffle,179seed,180} => IL::Sample {181is_fraction,182with_replacement,183shuffle,184seed,185},186L::Slice => IL::Slice,187L::Shift => IL::Shift,188L::Get(v) => IL::Get(v),189#[cfg(feature = "list_gather")]190L::Gather(v) => IL::Gather(v),191#[cfg(feature = "list_gather")]192L::GatherEvery => IL::GatherEvery,193#[cfg(feature = "list_count")]194L::CountMatches => IL::CountMatches,195L::Sum => IL::Sum,196L::Length => IL::Length,197L::Max => IL::Max,198L::Min => IL::Min,199L::Mean => IL::Mean,200L::Median => IL::Median,201L::Std(v) => IL::Std(v),202L::Var(v) => IL::Var(v),203L::ArgMin => IL::ArgMin,204L::ArgMax => IL::ArgMax,205#[cfg(feature = "diff")]206L::Diff { n, null_behavior } => IL::Diff { n, null_behavior },207L::Sort(sort_options) => IL::Sort(sort_options),208L::Reverse => IL::Reverse,209L::Unique(v) => IL::Unique(v),210L::NUnique => IL::NUnique,211#[cfg(feature = "list_sets")]212L::SetOperation(set_operation) => IL::SetOperation(set_operation),213#[cfg(feature = "list_any_all")]214L::Any => IL::Any,215#[cfg(feature = "list_any_all")]216L::All => IL::All,217L::Join(v) => IL::Join(v),218#[cfg(feature = "dtype-array")]219L::ToArray(v) => IL::ToArray(v),220#[cfg(feature = "list_to_struct")]221L::ToStruct(list_to_struct_args) => IL::ToStruct(list_to_struct_args),222})223},224#[cfg(feature = "strings")]225F::StringExpr(string_function) => {226use {IRStringFunction as IS, StringFunction as S};227I::StringExpr(match string_function {228S::Format { format, insertions } => {229if input_is_empty {230polars_ensure!(231insertions.is_empty(),232ComputeError: "StringFormat didn't get any inputs, format: \"{}\"",233format234);235236let out = ctx237.arena238.add(AExpr::Literal(LiteralValue::Scalar(Scalar::from(format))));239240return Ok((out, get_literal_name()));241} else {242IS::Format { format, insertions }243}244},245#[cfg(feature = "concat_str")]246S::ConcatHorizontal {247delimiter,248ignore_nulls,249} => IS::ConcatHorizontal {250delimiter,251ignore_nulls,252},253#[cfg(feature = "concat_str")]254S::ConcatVertical {255delimiter,256ignore_nulls,257} => IS::ConcatVertical {258delimiter,259ignore_nulls,260},261#[cfg(feature = "regex")]262S::Contains { literal, strict } => IS::Contains { literal, strict },263S::CountMatches(v) => IS::CountMatches(v),264S::EndsWith => IS::EndsWith,265S::Extract(v) => IS::Extract(v),266S::ExtractAll => IS::ExtractAll,267#[cfg(feature = "extract_groups")]268S::ExtractGroups { dtype, pat } => IS::ExtractGroups { dtype, pat },269#[cfg(feature = "regex")]270S::Find { literal, strict } => IS::Find { literal, strict },271#[cfg(feature = "string_to_integer")]272S::ToInteger { dtype, strict } => IS::ToInteger { dtype, strict },273S::LenBytes => IS::LenBytes,274S::LenChars => IS::LenChars,275S::Lowercase => IS::Lowercase,276#[cfg(feature = "extract_jsonpath")]277S::JsonDecode(dtype) => IS::JsonDecode(dtype.into_datatype(ctx.schema)?),278#[cfg(feature = "extract_jsonpath")]279S::JsonPathMatch => IS::JsonPathMatch,280#[cfg(feature = "regex")]281S::Replace { n, literal } => IS::Replace { n, literal },282#[cfg(feature = "string_normalize")]283S::Normalize { form } => IS::Normalize { form },284#[cfg(feature = "string_reverse")]285S::Reverse => IS::Reverse,286#[cfg(feature = "string_pad")]287S::PadStart { fill_char } => IS::PadStart { fill_char },288#[cfg(feature = "string_pad")]289S::PadEnd { fill_char } => IS::PadEnd { fill_char },290S::Slice => IS::Slice,291S::Head => IS::Head,292S::Tail => IS::Tail,293#[cfg(feature = "string_encoding")]294S::HexEncode => IS::HexEncode,295#[cfg(feature = "binary_encoding")]296S::HexDecode(v) => IS::HexDecode(v),297#[cfg(feature = "string_encoding")]298S::Base64Encode => IS::Base64Encode,299#[cfg(feature = "binary_encoding")]300S::Base64Decode(v) => IS::Base64Decode(v),301S::StartsWith => IS::StartsWith,302S::StripChars => IS::StripChars,303S::StripCharsStart => IS::StripCharsStart,304S::StripCharsEnd => IS::StripCharsEnd,305S::StripPrefix => IS::StripPrefix,306S::StripSuffix => IS::StripSuffix,307#[cfg(feature = "dtype-struct")]308S::SplitExact { n, inclusive } => IS::SplitExact { n, inclusive },309#[cfg(feature = "dtype-struct")]310S::SplitN(v) => IS::SplitN(v),311#[cfg(feature = "temporal")]312S::Strptime(data_type, strptime_options) => {313let is_column_independent = is_column_independent_aexpr(e[0].node(), ctx.arena);314set_elementwise = is_column_independent;315let dtype = data_type.into_datatype(ctx.schema)?;316polars_ensure!(317matches!(dtype,318DataType::Date |319DataType::Datetime(_, _) |320DataType::Time321),322InvalidOperation: "`strptime` expects a `date`, `datetime` or `time` got {dtype}"323);324IS::Strptime(dtype, strptime_options)325},326S::Split(v) => IS::Split(v),327#[cfg(feature = "dtype-decimal")]328S::ToDecimal { scale } => IS::ToDecimal { scale },329#[cfg(feature = "nightly")]330S::Titlecase => IS::Titlecase,331S::Uppercase => IS::Uppercase,332#[cfg(feature = "string_pad")]333S::ZFill => IS::ZFill,334#[cfg(feature = "find_many")]335S::ContainsAny {336ascii_case_insensitive,337} => IS::ContainsAny {338ascii_case_insensitive,339},340#[cfg(feature = "find_many")]341S::ReplaceMany {342ascii_case_insensitive,343leftmost,344} => IS::ReplaceMany {345ascii_case_insensitive,346leftmost,347},348#[cfg(feature = "find_many")]349S::ExtractMany {350ascii_case_insensitive,351overlapping,352leftmost,353} => IS::ExtractMany {354ascii_case_insensitive,355overlapping,356leftmost,357},358#[cfg(feature = "find_many")]359S::FindMany {360ascii_case_insensitive,361overlapping,362leftmost,363} => IS::FindMany {364ascii_case_insensitive,365overlapping,366leftmost,367},368#[cfg(feature = "regex")]369S::EscapeRegex => IS::EscapeRegex,370})371},372#[cfg(feature = "dtype-struct")]373F::StructExpr(struct_function) => {374use {IRStructFunction as IS, StructFunction as S};375I::StructExpr(match struct_function {376S::FieldByName(pl_small_str) => IS::FieldByName(pl_small_str),377S::RenameFields(pl_small_strs) => IS::RenameFields(pl_small_strs),378S::PrefixFields(pl_small_str) => IS::PrefixFields(pl_small_str),379S::SuffixFields(pl_small_str) => IS::SuffixFields(pl_small_str),380S::SelectFields(_) => unreachable!("handled by expression expansion"),381#[cfg(feature = "json")]382S::JsonEncode => IS::JsonEncode,383S::WithFields => unreachable!("handled before"),384S::MapFieldNames(f) => IS::MapFieldNames(f),385})386},387#[cfg(feature = "temporal")]388F::TemporalExpr(temporal_function) => {389use {IRTemporalFunction as IT, TemporalFunction as T};390I::TemporalExpr(match temporal_function {391T::Millennium => IT::Millennium,392T::Century => IT::Century,393T::Year => IT::Year,394T::IsLeapYear => IT::IsLeapYear,395T::IsoYear => IT::IsoYear,396T::Quarter => IT::Quarter,397T::Month => IT::Month,398T::DaysInMonth => IT::DaysInMonth,399T::Week => IT::Week,400T::WeekDay => IT::WeekDay,401T::Day => IT::Day,402T::OrdinalDay => IT::OrdinalDay,403T::Time => IT::Time,404T::Date => IT::Date,405T::Datetime => IT::Datetime,406#[cfg(feature = "dtype-duration")]407T::Duration(time_unit) => IT::Duration(time_unit),408T::Hour => IT::Hour,409T::Minute => IT::Minute,410T::Second => IT::Second,411T::Millisecond => IT::Millisecond,412T::Microsecond => IT::Microsecond,413T::Nanosecond => IT::Nanosecond,414#[cfg(feature = "dtype-duration")]415T::TotalDays { fractional } => IT::TotalDays { fractional },416#[cfg(feature = "dtype-duration")]417T::TotalHours { fractional } => IT::TotalHours { fractional },418#[cfg(feature = "dtype-duration")]419T::TotalMinutes { fractional } => IT::TotalMinutes { fractional },420#[cfg(feature = "dtype-duration")]421T::TotalSeconds { fractional } => IT::TotalSeconds { fractional },422#[cfg(feature = "dtype-duration")]423T::TotalMilliseconds { fractional } => IT::TotalMilliseconds { fractional },424#[cfg(feature = "dtype-duration")]425T::TotalMicroseconds { fractional } => IT::TotalMicroseconds { fractional },426#[cfg(feature = "dtype-duration")]427T::TotalNanoseconds { fractional } => IT::TotalNanoseconds { fractional },428T::ToString(v) => IT::ToString(v),429T::CastTimeUnit(time_unit) => IT::CastTimeUnit(time_unit),430T::WithTimeUnit(time_unit) => IT::WithTimeUnit(time_unit),431#[cfg(feature = "timezones")]432T::ConvertTimeZone(time_zone) => IT::ConvertTimeZone(time_zone),433T::TimeStamp(time_unit) => IT::TimeStamp(time_unit),434T::Truncate => IT::Truncate,435#[cfg(feature = "offset_by")]436T::OffsetBy => IT::OffsetBy,437#[cfg(feature = "month_start")]438T::MonthStart => IT::MonthStart,439#[cfg(feature = "month_end")]440T::MonthEnd => IT::MonthEnd,441#[cfg(feature = "timezones")]442T::BaseUtcOffset => IT::BaseUtcOffset,443#[cfg(feature = "timezones")]444T::DSTOffset => IT::DSTOffset,445T::Round => IT::Round,446T::Replace => IT::Replace,447#[cfg(feature = "timezones")]448T::ReplaceTimeZone(time_zone, non_existent) => {449IT::ReplaceTimeZone(time_zone, non_existent)450},451T::Combine(time_unit) => IT::Combine(time_unit),452T::DatetimeFunction {453time_unit,454time_zone,455} => IT::DatetimeFunction {456time_unit,457time_zone,458},459})460},461#[cfg(feature = "bitwise")]462F::Bitwise(bitwise_function) => I::Bitwise(match bitwise_function {463BitwiseFunction::CountOnes => IRBitwiseFunction::CountOnes,464BitwiseFunction::CountZeros => IRBitwiseFunction::CountZeros,465BitwiseFunction::LeadingOnes => IRBitwiseFunction::LeadingOnes,466BitwiseFunction::LeadingZeros => IRBitwiseFunction::LeadingZeros,467BitwiseFunction::TrailingOnes => IRBitwiseFunction::TrailingOnes,468BitwiseFunction::TrailingZeros => IRBitwiseFunction::TrailingZeros,469BitwiseFunction::And => IRBitwiseFunction::And,470BitwiseFunction::Or => IRBitwiseFunction::Or,471BitwiseFunction::Xor => IRBitwiseFunction::Xor,472}),473F::Boolean(boolean_function) => {474use {BooleanFunction as B, IRBooleanFunction as IB};475I::Boolean(match boolean_function {476B::Any { ignore_nulls } => IB::Any { ignore_nulls },477B::All { ignore_nulls } => IB::All { ignore_nulls },478B::IsNull => IB::IsNull,479B::IsNotNull => IB::IsNotNull,480B::IsFinite => IB::IsFinite,481B::IsInfinite => IB::IsInfinite,482B::IsNan => IB::IsNan,483B::IsNotNan => IB::IsNotNan,484#[cfg(feature = "is_first_distinct")]485B::IsFirstDistinct => IB::IsFirstDistinct,486#[cfg(feature = "is_last_distinct")]487B::IsLastDistinct => IB::IsLastDistinct,488#[cfg(feature = "is_unique")]489B::IsUnique => IB::IsUnique,490#[cfg(feature = "is_unique")]491B::IsDuplicated => IB::IsDuplicated,492#[cfg(feature = "is_between")]493B::IsBetween { closed } => IB::IsBetween { closed },494#[cfg(feature = "is_in")]495B::IsIn { nulls_equal } => IB::IsIn { nulls_equal },496#[cfg(feature = "is_close")]497B::IsClose {498abs_tol,499rel_tol,500nans_equal,501} => IB::IsClose {502abs_tol,503rel_tol,504nans_equal,505},506B::AllHorizontal => {507let Some(fst) = e.first() else {508return Ok((509ctx.arena.add(AExpr::Literal(Scalar::from(true).into())),510format_pl_smallstr!("{}", IB::AllHorizontal),511));512};513514if e.len() == 1 {515return Ok((516AExprBuilder::new_from_node(fst.node())517.cast(DataType::Boolean, ctx.arena)518.node(),519fst.output_name().clone(),520));521}522523// Convert to binary expression as the optimizer understands those.524// Don't exceed 128 expressions as we might stackoverflow.525if e.len() < 128 {526let mut r = AExprBuilder::new_from_node(fst.node());527for expr in &e[1..] {528r = r.logical_and(expr.node(), ctx.arena);529}530return Ok((r.node(), fst.output_name().clone()));531}532533IB::AllHorizontal534},535B::AnyHorizontal => {536// This can be created by col(*).is_null() on empty dataframes.537let Some(fst) = e.first() else {538return Ok((539ctx.arena.add(AExpr::Literal(Scalar::from(false).into())),540format_pl_smallstr!("{}", IB::AnyHorizontal),541));542};543544if e.len() == 1 {545return Ok((546AExprBuilder::new_from_node(fst.node())547.cast(DataType::Boolean, ctx.arena)548.node(),549fst.output_name().clone(),550));551}552553// Convert to binary expression as the optimizer understands those.554// Don't exceed 128 expressions as we might stackoverflow.555if e.len() < 128 {556let mut r = AExprBuilder::new_from_node(fst.node());557for expr in &e[1..] {558r = r.logical_or(expr.node(), ctx.arena);559}560return Ok((r.node(), fst.output_name().clone()));561}562563IB::AnyHorizontal564},565B::Not => IB::Not,566})567},568#[cfg(feature = "business")]569F::Business(business_function) => I::Business(match business_function {570BusinessFunction::BusinessDayCount {571week_mask,572holidays,573} => IRBusinessFunction::BusinessDayCount {574week_mask,575holidays,576},577BusinessFunction::AddBusinessDay {578week_mask,579holidays,580roll,581} => IRBusinessFunction::AddBusinessDay {582week_mask,583holidays,584roll,585},586BusinessFunction::IsBusinessDay {587week_mask,588holidays,589} => IRBusinessFunction::IsBusinessDay {590week_mask,591holidays,592},593}),594#[cfg(feature = "abs")]595F::Abs => I::Abs,596F::Negate => I::Negate,597#[cfg(feature = "hist")]598F::Hist {599bin_count,600include_category,601include_breakpoint,602} => I::Hist {603bin_count,604include_category,605include_breakpoint,606},607F::NullCount => I::NullCount,608F::Pow(pow_function) => I::Pow(match pow_function {609PowFunction::Generic => IRPowFunction::Generic,610PowFunction::Sqrt => IRPowFunction::Sqrt,611PowFunction::Cbrt => IRPowFunction::Cbrt,612}),613#[cfg(feature = "row_hash")]614F::Hash(s0, s1, s2, s3) => I::Hash(s0, s1, s2, s3),615#[cfg(feature = "arg_where")]616F::ArgWhere => I::ArgWhere,617#[cfg(feature = "index_of")]618F::IndexOf => I::IndexOf,619#[cfg(feature = "search_sorted")]620F::SearchSorted { side, descending } => I::SearchSorted { side, descending },621#[cfg(feature = "range")]622F::Range(range_function) => I::Range(match range_function {623RangeFunction::IntRange { step, dtype } => {624let dtype = dtype.into_datatype(ctx.schema)?;625polars_ensure!(e[0].is_scalar(ctx.arena), ShapeMismatch: "non-scalar start passed to `int_range`");626polars_ensure!(e[1].is_scalar(ctx.arena), ShapeMismatch: "non-scalar stop passed to `int_range`");627polars_ensure!(dtype.is_integer(), SchemaMismatch: "non-integer `dtype` passed to `int_range`: '{dtype}'");628IRRangeFunction::IntRange { step, dtype }629},630RangeFunction::IntRanges { dtype } => {631let dtype = dtype.into_datatype(ctx.schema)?;632polars_ensure!(dtype.is_integer(), SchemaMismatch: "non-integer `dtype` passed to `int_ranges`: '{dtype}'");633IRRangeFunction::IntRanges { dtype }634},635RangeFunction::LinearSpace { closed } => {636polars_ensure!(e[0].is_scalar(ctx.arena), ShapeMismatch: "non-scalar start passed to `linear_space`");637polars_ensure!(e[1].is_scalar(ctx.arena), ShapeMismatch: "non-scalar end passed to `linear_space`");638polars_ensure!(e[2].is_scalar(ctx.arena), ShapeMismatch: "non-scalar num_samples passed to `linear_space`");639IRRangeFunction::LinearSpace { closed }640},641RangeFunction::LinearSpaces {642closed,643array_width,644} => IRRangeFunction::LinearSpaces {645closed,646array_width,647},648#[cfg(all(feature = "range", feature = "dtype-date"))]649RangeFunction::DateRange {650interval,651closed,652arg_type,653} => {654use DateRangeArgs::*;655let arg_names = match arg_type {656StartEndSamples => vec!["start", "end", "num_samples"],657StartEndInterval => vec!["start", "end"],658StartIntervalSamples => vec!["start", "num_samples"],659EndIntervalSamples => vec!["end", "num_samples"],660};661for (idx, &name) in arg_names.iter().enumerate() {662polars_ensure!(e[idx].is_scalar(ctx.arena), ShapeMismatch: "non-scalar {name} passed to `date_range`");663}664IRRangeFunction::DateRange {665interval,666closed,667arg_type,668}669},670#[cfg(all(feature = "range", feature = "dtype-date"))]671RangeFunction::DateRanges {672interval,673closed,674arg_type,675} => IRRangeFunction::DateRanges {676interval,677closed,678arg_type,679},680#[cfg(all(feature = "range", feature = "dtype-datetime"))]681RangeFunction::DatetimeRange {682interval,683closed,684time_unit,685time_zone,686arg_type,687} => {688use DateRangeArgs::*;689let arg_names = match arg_type {690StartEndSamples => vec!["start", "end", "num_samples"],691StartEndInterval => vec!["start", "end"],692StartIntervalSamples => vec!["start", "num_samples"],693EndIntervalSamples => vec!["end", "num_samples"],694};695for (idx, &name) in arg_names.iter().enumerate() {696polars_ensure!(e[idx].is_scalar(ctx.arena), ShapeMismatch: "non-scalar {name} passed to `datetime_range`");697}698IRRangeFunction::DatetimeRange {699interval,700closed,701time_unit,702time_zone,703arg_type,704}705},706#[cfg(all(feature = "range", feature = "dtype-datetime"))]707RangeFunction::DatetimeRanges {708interval,709closed,710time_unit,711time_zone,712arg_type,713} => IRRangeFunction::DatetimeRanges {714interval,715closed,716time_unit,717time_zone,718arg_type,719},720#[cfg(all(feature = "range", feature = "dtype-time"))]721RangeFunction::TimeRange { interval, closed } => {722polars_ensure!(e[0].is_scalar(ctx.arena), ShapeMismatch: "non-scalar start passed to `time_range`");723polars_ensure!(e[1].is_scalar(ctx.arena), ShapeMismatch: "non-scalar end passed to `time_range`");724IRRangeFunction::TimeRange { interval, closed }725},726#[cfg(all(feature = "range", feature = "dtype-time"))]727RangeFunction::TimeRanges { interval, closed } => {728IRRangeFunction::TimeRanges { interval, closed }729},730}),731#[cfg(feature = "trigonometry")]732F::Trigonometry(trigonometric_function) => {733use {IRTrigonometricFunction as IT, TrigonometricFunction as T};734I::Trigonometry(match trigonometric_function {735T::Cos => IT::Cos,736T::Cot => IT::Cot,737T::Sin => IT::Sin,738T::Tan => IT::Tan,739T::ArcCos => IT::ArcCos,740T::ArcSin => IT::ArcSin,741T::ArcTan => IT::ArcTan,742T::Cosh => IT::Cosh,743T::Sinh => IT::Sinh,744T::Tanh => IT::Tanh,745T::ArcCosh => IT::ArcCosh,746T::ArcSinh => IT::ArcSinh,747T::ArcTanh => IT::ArcTanh,748T::Degrees => IT::Degrees,749T::Radians => IT::Radians,750})751},752#[cfg(feature = "trigonometry")]753F::Atan2 => I::Atan2,754#[cfg(feature = "sign")]755F::Sign => I::Sign,756F::FillNull => I::FillNull,757F::FillNullWithStrategy(fill_null_strategy) => I::FillNullWithStrategy(fill_null_strategy),758#[cfg(feature = "rolling_window")]759F::RollingExpr { function, options } => {760use RollingFunction as R;761use aexpr::IRRollingFunction as IR;762763I::RollingExpr {764function: match function {765R::Min => IR::Min,766R::Max => IR::Max,767R::Mean => IR::Mean,768R::Sum => IR::Sum,769R::Quantile => IR::Quantile,770R::Var => IR::Var,771R::Std => IR::Std,772R::Rank => IR::Rank,773#[cfg(feature = "moment")]774R::Skew => IR::Skew,775#[cfg(feature = "moment")]776R::Kurtosis => IR::Kurtosis,777#[cfg(feature = "cov")]778R::CorrCov {779corr_cov_options,780is_corr,781} => IR::CorrCov {782corr_cov_options,783is_corr,784},785R::Map(f) => IR::Map(f),786},787options,788}789},790#[cfg(feature = "rolling_window_by")]791F::RollingExprBy {792function_by,793options,794} => {795use RollingFunctionBy as R;796use aexpr::IRRollingFunctionBy as IR;797798I::RollingExprBy {799function_by: match function_by {800R::MinBy => IR::MinBy,801R::MaxBy => IR::MaxBy,802R::MeanBy => IR::MeanBy,803R::SumBy => IR::SumBy,804R::QuantileBy => IR::QuantileBy,805R::VarBy => IR::VarBy,806R::StdBy => IR::StdBy,807R::RankBy => IR::RankBy,808},809options,810}811},812F::Rechunk => I::Rechunk,813F::Append { upcast } => I::Append { upcast },814F::ShiftAndFill => {815polars_ensure!(&e[1].is_scalar(ctx.arena), ShapeMismatch: "'n' must be a scalar value");816polars_ensure!(&e[2].is_scalar(ctx.arena), ShapeMismatch: "'fill_value' must be a scalar value");817I::ShiftAndFill818},819F::Shift => {820polars_ensure!(&e[1].is_scalar(ctx.arena), ShapeMismatch: "'n' must be a scalar value");821I::Shift822},823F::DropNans => I::DropNans,824F::DropNulls => I::DropNulls,825#[cfg(feature = "mode")]826F::Mode { maintain_order } => I::Mode { maintain_order },827#[cfg(feature = "moment")]828F::Skew(v) => I::Skew(v),829#[cfg(feature = "moment")]830F::Kurtosis(l, r) => I::Kurtosis(l, r),831#[cfg(feature = "dtype-array")]832F::Reshape(reshape_dimensions) => I::Reshape(reshape_dimensions),833#[cfg(feature = "repeat_by")]834F::RepeatBy => I::RepeatBy,835F::ArgUnique => I::ArgUnique,836F::ArgMin => I::ArgMin,837F::ArgMax => I::ArgMax,838F::ArgSort {839descending,840nulls_last,841} => I::ArgSort {842descending,843nulls_last,844},845F::Product => I::Product,846#[cfg(feature = "rank")]847F::Rank { options, seed } => I::Rank { options, seed },848F::Repeat => {849polars_ensure!(&e[0].is_scalar(ctx.arena), ShapeMismatch: "'value' must be a scalar value");850polars_ensure!(&e[1].is_scalar(ctx.arena), ShapeMismatch: "'n' must be a scalar value");851I::Repeat852},853#[cfg(feature = "round_series")]854F::Clip { has_min, has_max } => I::Clip { has_min, has_max },855#[cfg(feature = "dtype-struct")]856F::AsStruct => I::AsStruct,857#[cfg(feature = "top_k")]858F::TopK { descending } => I::TopK { descending },859#[cfg(feature = "top_k")]860F::TopKBy { descending } => I::TopKBy { descending },861#[cfg(feature = "cum_agg")]862F::CumCount { reverse } => I::CumCount { reverse },863#[cfg(feature = "cum_agg")]864F::CumSum { reverse } => I::CumSum { reverse },865#[cfg(feature = "cum_agg")]866F::CumProd { reverse } => I::CumProd { reverse },867#[cfg(feature = "cum_agg")]868F::CumMin { reverse } => I::CumMin { reverse },869#[cfg(feature = "cum_agg")]870F::CumMax { reverse } => I::CumMax { reverse },871F::Reverse => I::Reverse,872#[cfg(feature = "dtype-struct")]873F::ValueCounts {874sort,875parallel,876name,877normalize,878} => I::ValueCounts {879sort,880parallel,881name,882normalize,883},884#[cfg(feature = "unique_counts")]885F::UniqueCounts => I::UniqueCounts,886#[cfg(feature = "approx_unique")]887F::ApproxNUnique => I::ApproxNUnique,888F::Coalesce => I::Coalesce,889#[cfg(feature = "diff")]890F::Diff(n) => {891polars_ensure!(&e[1].is_scalar(ctx.arena), ShapeMismatch: "'n' must be a scalar value");892I::Diff(n)893},894#[cfg(feature = "pct_change")]895F::PctChange => I::PctChange,896#[cfg(feature = "interpolate")]897F::Interpolate(interpolation_method) => I::Interpolate(interpolation_method),898#[cfg(feature = "interpolate_by")]899F::InterpolateBy => I::InterpolateBy,900#[cfg(feature = "log")]901F::Entropy { base, normalize } => I::Entropy { base, normalize },902#[cfg(feature = "log")]903F::Log => I::Log,904#[cfg(feature = "log")]905F::Log1p => I::Log1p,906#[cfg(feature = "log")]907F::Exp => I::Exp,908F::Unique(v) => I::Unique(v),909#[cfg(feature = "round_series")]910F::Round { decimals, mode } => I::Round { decimals, mode },911#[cfg(feature = "round_series")]912F::RoundSF { digits } => I::RoundSF { digits },913#[cfg(feature = "round_series")]914F::Floor => I::Floor,915#[cfg(feature = "round_series")]916F::Ceil => I::Ceil,917F::UpperBound => {918let field = e[0].field(ctx.schema, ctx.arena)?;919return Ok((920ctx.arena921.add(AExpr::Literal(field.dtype.to_physical().max()?.into())),922field.name,923));924},925F::LowerBound => {926let field = e[0].field(ctx.schema, ctx.arena)?;927return Ok((928ctx.arena929.add(AExpr::Literal(field.dtype.to_physical().min()?.into())),930field.name,931));932},933F::ConcatExpr(v) => I::ConcatExpr(v),934#[cfg(feature = "cov")]935F::Correlation { method } => {936use {CorrelationMethod as C, IRCorrelationMethod as IC};937I::Correlation {938method: match method {939C::Pearson => IC::Pearson,940#[cfg(all(feature = "rank", feature = "propagate_nans"))]941C::SpearmanRank(v) => IC::SpearmanRank(v),942C::Covariance(v) => IC::Covariance(v),943},944}945},946#[cfg(feature = "peaks")]947F::PeakMin => I::PeakMin,948#[cfg(feature = "peaks")]949F::PeakMax => I::PeakMax,950#[cfg(feature = "cutqcut")]951F::Cut {952breaks,953labels,954left_closed,955include_breaks,956} => I::Cut {957breaks,958labels,959left_closed,960include_breaks,961},962#[cfg(feature = "cutqcut")]963F::QCut {964probs,965labels,966left_closed,967allow_duplicates,968include_breaks,969} => I::QCut {970probs,971labels,972left_closed,973allow_duplicates,974include_breaks,975},976#[cfg(feature = "rle")]977F::RLE => I::RLE,978#[cfg(feature = "rle")]979F::RLEID => I::RLEID,980F::ToPhysical => I::ToPhysical,981#[cfg(feature = "random")]982F::Random { method, seed } => {983use {IRRandomMethod as IR, RandomMethod as R};984I::Random {985method: match method {986R::Shuffle => IR::Shuffle,987R::Sample {988is_fraction,989with_replacement,990shuffle,991} => IR::Sample {992is_fraction,993with_replacement,994shuffle,995},996},997seed,998}999},1000F::SetSortedFlag(is_sorted) => I::SetSortedFlag(is_sorted),1001#[cfg(feature = "ffi_plugin")]1002F::FfiPlugin {1003flags,1004lib,1005symbol,1006kwargs,1007} => I::FfiPlugin {1008flags,1009lib,1010symbol,1011kwargs,1012},10131014F::FoldHorizontal {1015callback,1016returns_scalar,1017return_dtype,1018} => I::FoldHorizontal {1019callback,1020returns_scalar,1021return_dtype: return_dtype.try_map(|dtype| dtype.into_datatype(ctx.schema))?,1022},1023F::ReduceHorizontal {1024callback,1025returns_scalar,1026return_dtype,1027} => I::ReduceHorizontal {1028callback,1029returns_scalar,1030return_dtype: return_dtype.try_map(|dtype| dtype.into_datatype(ctx.schema))?,1031},1032#[cfg(feature = "dtype-struct")]1033F::CumReduceHorizontal {1034callback,1035returns_scalar,1036return_dtype,1037} => I::CumReduceHorizontal {1038callback,1039returns_scalar,1040return_dtype: return_dtype.try_map(|dtype| dtype.into_datatype(ctx.schema))?,1041},1042#[cfg(feature = "dtype-struct")]1043F::CumFoldHorizontal {1044callback,1045returns_scalar,1046return_dtype,1047include_init,1048} => I::CumFoldHorizontal {1049callback,1050returns_scalar,1051return_dtype: return_dtype.try_map(|dtype| dtype.into_datatype(ctx.schema))?,1052include_init,1053},10541055F::MaxHorizontal => I::MaxHorizontal,1056F::MinHorizontal => I::MinHorizontal,1057F::SumHorizontal { ignore_nulls } => I::SumHorizontal { ignore_nulls },1058F::MeanHorizontal { ignore_nulls } => I::MeanHorizontal { ignore_nulls },1059#[cfg(feature = "ewma")]1060F::EwmMean { options } => I::EwmMean { options },1061#[cfg(feature = "ewma_by")]1062F::EwmMeanBy { half_life } => I::EwmMeanBy { half_life },1063#[cfg(feature = "ewma")]1064F::EwmStd { options } => I::EwmStd { options },1065#[cfg(feature = "ewma")]1066F::EwmVar { options } => I::EwmVar { options },1067#[cfg(feature = "replace")]1068F::Replace => I::Replace,1069#[cfg(feature = "replace")]1070F::ReplaceStrict { return_dtype } => I::ReplaceStrict {1071return_dtype: match return_dtype {1072Some(dtype) => Some(dtype.into_datatype(ctx.schema)?),1073None => None,1074},1075},1076F::GatherEvery { n, offset } => I::GatherEvery { n, offset },1077#[cfg(feature = "reinterpret")]1078F::Reinterpret(v) => I::Reinterpret(v),1079F::ExtendConstant => {1080polars_ensure!(&e[1].is_scalar(ctx.arena), ShapeMismatch: "'value' must be a scalar value");1081polars_ensure!(&e[2].is_scalar(ctx.arena), ShapeMismatch: "'n' must be a scalar value");1082I::ExtendConstant1083},10841085F::RowEncode(v) => {1086let dts = e1087.iter()1088.map(|e| Ok(e.dtype(ctx.schema, ctx.arena)?.clone()))1089.collect::<PolarsResult<Vec<_>>>()?;1090I::RowEncode(dts, v)1091},1092#[cfg(feature = "dtype-struct")]1093F::RowDecode(fs, v) => I::RowDecode(1094fs.into_iter()1095.map(|(name, dt_expr)| Ok(Field::new(name, dt_expr.into_datatype(ctx.schema)?)))1096.collect::<PolarsResult<Vec<_>>>()?,1097v,1098),1099};11001101let mut options = ir_function.function_options();1102if set_elementwise {1103options.set_elementwise();1104}11051106// Handles special case functions like `struct.field`.1107let output_name = match ir_function.output_name().and_then(|v| v.into_inner()) {1108Some(name) => name,1109None if e.is_empty() => format_pl_smallstr!("{}", &ir_function),1110None => e[0].output_name().clone(),1111};11121113let ae_function = AExpr::Function {1114input: e,1115function: ir_function,1116options,1117};1118Ok((ctx.arena.add(ae_function), output_name))1119}112011211122