use std::ops::Sub;
use polars_core::chunked_array::ops::{SortMultipleOptions, SortOptions};
use polars_core::prelude::{
DataType, PolarsResult, QuantileMethod, Schema, TimeUnit, polars_bail, polars_err,
};
use polars_lazy::dsl::Expr;
use polars_ops::chunked_array::UnicodeForm;
use polars_ops::series::RoundMode;
use polars_plan::dsl::{coalesce, concat_str, len, max_horizontal, min_horizontal, when};
use polars_plan::plans::{DynLiteralValue, LiteralValue, typed_lit};
use polars_plan::prelude::{StrptimeOptions, col, cols, lit};
use polars_utils::pl_str::PlSmallStr;
use sqlparser::ast::helpers::attached_token::AttachedToken;
use sqlparser::ast::{
DateTimeField, DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg,
FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, Ident,
OrderByExpr, Value as SQLValue, WindowSpec, WindowType,
};
use sqlparser::tokenizer::Span;
use crate::SQLContext;
use crate::sql_expr::{adjust_one_indexed_param, parse_extract_date_part, parse_sql_expr};
/// Visitor that translates one parsed SQL function call into a Polars
/// expression (see `visit_function`).
pub(crate) struct SQLFunctionVisitor<'a> {
/// The parsed SQL function call being translated.
pub(crate) func: &'a SQLFunction,
/// SQL context handed to `parse_sql_expr` when parsing arguments; its
/// function registry is consulted for user-defined functions.
pub(crate) ctx: &'a mut SQLContext,
/// Optional schema forwarded to `parse_sql_expr`; `COLUMNS('<regex>')` also
/// matches its pattern against the column names of this schema when present.
pub(crate) active_schema: Option<&'a Schema>,
}
/// Every SQL function understood by the SQL interface.
///
/// `try_from_sql` resolves a lower-cased SQL function name to one of these
/// variants, and `SQLFunctionVisitor::visit_function` translates the variant
/// into an expression. The SQL names quoted below are the ones accepted by
/// `try_from_sql`.
pub(crate) enum PolarsSQLFunctions {
// ----
// Bitwise functions
// ----
/// SQL `bit_and` / `bitand` (two arguments; dispatched to `Expr::and`).
BitAnd,
/// SQL `bit_count` / `bitcount` (only with the `bitwise` feature).
#[cfg(feature = "bitwise")]
BitCount,
/// SQL `bit_or` / `bitor` (two arguments; dispatched to `Expr::or`).
BitOr,
/// SQL `bit_xor` / `bitxor` / `xor` (two arguments; dispatched to `Expr::xor`).
BitXor,
// ----
// Math functions
// ----
/// SQL `abs`.
Abs,
/// SQL `ceil` / `ceiling`.
Ceil,
/// SQL `div`: floor division of two arguments, cast to `Int64`.
Div,
/// SQL `exp`.
Exp,
/// SQL `floor`.
Floor,
/// SQL `pi` (takes no arguments).
Pi,
/// SQL `ln`: logarithm with base `e`.
Ln,
/// SQL `log2`: logarithm with base 2.
Log2,
/// SQL `log10`: logarithm with base 10.
Log10,
/// SQL `log`: two-argument logarithm (dispatched to `Expr::log`).
Log,
/// SQL `log1p`.
Log1p,
/// SQL `pow` / `power`.
Pow,
/// SQL `mod`: remainder of the first argument divided by the second (`%`).
Mod,
/// SQL `sqrt`.
Sqrt,
/// SQL `cbrt`.
Cbrt,
/// SQL `round`: one argument (zero decimals) or two (non-negative literal decimals).
Round,
/// SQL `sign`.
Sign,
// ----
// Trigonometric functions
// ----
/// SQL `cos`.
Cos,
/// SQL `cot`.
Cot,
/// SQL `sin`.
Sin,
/// SQL `tan`.
Tan,
/// SQL `cosd`: the input is converted with `radians()` before `cos()`.
CosD,
/// SQL `cotd`: the input is converted with `radians()` before `cot()`.
CotD,
/// SQL `sind`: the input is converted with `radians()` before `sin()`.
SinD,
/// SQL `tand`: the input is converted with `radians()` before `tan()`.
TanD,
/// SQL `acos`.
Acos,
/// SQL `asin`.
Asin,
/// SQL `atan`.
Atan,
/// SQL `atan2` (two arguments).
Atan2,
/// SQL `acosd`: `arccos()` followed by `degrees()`.
AcosD,
/// SQL `asind`: `arcsin()` followed by `degrees()`.
AsinD,
/// SQL `atand`: `arctan()` followed by `degrees()`.
AtanD,
/// SQL `atan2d`: `arctan2()` followed by `degrees()`.
Atan2D,
/// SQL `degrees`.
Degrees,
/// SQL `radians`.
Radians,
// ----
// Temporal functions
// ----
/// SQL `date_part`: the first argument must be a string literal naming the part.
DatePart,
/// SQL `strftime`: format a temporal value using a literal format string.
Strftime,
// ----
// String functions
// ----
/// SQL `bit_length`: byte length of the string multiplied by 8.
BitLength,
/// SQL `concat`: concatenates one or more arguments with an empty separator.
Concat,
/// SQL `concat_ws`: the first argument is a literal string separator.
ConcatWS,
/// SQL `date`: parses a string to a date, with an optional format argument.
Date,
/// SQL `ends_with`.
EndsWith,
/// SQL `initcap` (only with the `nightly` feature): title-cases the string.
#[cfg(feature = "nightly")]
InitCap,
/// SQL `left`: leading characters of a string.
Left,
/// SQL `length` / `char_length` / `character_length`: length in characters.
Length,
/// SQL `lower`.
Lower,
/// SQL `ltrim`: strips leading characters (one or two arguments).
LTrim,
/// SQL `normalize`: Unicode normalization; the form defaults to NFC.
Normalize,
/// SQL `octet_length`: length in bytes.
OctetLength,
/// SQL `regexp_like`: regex match, with optional literal flags as a third argument.
RegexpLike,
/// SQL `replace`: replaces all occurrences (three arguments).
Replace,
/// SQL `reverse`: reverses a string.
Reverse,
/// SQL `right`: trailing characters of a string.
Right,
/// SQL `rtrim`: strips trailing characters (one or two arguments).
RTrim,
/// SQL `split_part`: splits on a separator and takes the 1-indexed part.
SplitPart,
/// SQL `starts_with`.
StartsWith,
/// SQL `strpos`: `find` result plus one, with null results replaced by 0.
StrPos,
/// SQL `substr`: substring from a 1-indexed start, with optional length.
Substring,
/// SQL `string_to_array`: splits a string on a separator.
StringToArray,
/// SQL `strptime`: parses a string to a microsecond datetime using a literal format.
Strptime,
/// SQL `time`: parses a string to a time, with an optional format argument.
Time,
/// SQL `timestamp` / `datetime`: parses a string to a datetime, with an optional format.
Timestamp,
/// SQL `upper`.
Upper,
// ----
// Conditional functions
// ----
/// SQL `coalesce`.
Coalesce,
/// SQL `greatest`: horizontal maximum of the arguments.
Greatest,
/// SQL `if`: `when(cond).then(a).otherwise(b)` (exactly three arguments).
If,
/// SQL `ifnull`: two-argument `coalesce`.
IfNull,
/// SQL `least`: horizontal minimum of the arguments.
Least,
/// SQL `nullif`: null when both arguments are equal, otherwise the first argument.
NullIf,
// ----
// Aggregate functions
// ----
/// SQL `avg`: mean.
Avg,
/// SQL `corr`: Pearson correlation of two arguments.
Corr,
/// SQL `count`, including `count(*)` and `count(DISTINCT ...)`.
Count,
/// SQL `covar_pop`: covariance with ddof 0.
CovarPop,
/// SQL `covar` / `covar_samp`: covariance with ddof 1.
CovarSamp,
/// SQL `first`.
First,
/// SQL `last`.
Last,
/// SQL `max`; may become a cumulative max under an `OVER (ORDER BY ...)` window.
Max,
/// SQL `median`.
Median,
/// SQL `quantile_cont`: quantile using `QuantileMethod::Linear`.
QuantileCont,
/// SQL `quantile_disc`: quantile using `QuantileMethod::Equiprobable`.
QuantileDisc,
/// SQL `min`; may become a cumulative min under an `OVER (ORDER BY ...)` window.
Min,
/// SQL `stdev` / `stddev` / `stdev_samp` / `stddev_samp`: `std(1)`.
StdDev,
/// SQL `sum`; may become a cumulative sum under an `OVER (ORDER BY ...)` window.
Sum,
/// SQL `var` / `variance` / `var_samp`: `var(1)`.
Variance,
// ----
// Array functions
// ----
/// SQL `array_length`.
ArrayLength,
/// SQL `array_lower`: minimum list element.
ArrayMin,
/// SQL `array_upper`: maximum list element.
ArrayMax,
/// SQL `array_sum`.
ArraySum,
/// SQL `array_mean`.
ArrayMean,
/// SQL `array_reverse`.
ArrayReverse,
/// SQL `array_unique`.
ArrayUnique,
/// SQL `unnest`: explodes a list column.
Explode,
/// SQL `array_agg`: implodes values, honouring `DISTINCT`, `ORDER BY` and `LIMIT`.
ArrayAgg,
/// SQL `array_to_string`: joins list elements with a separator.
ArrayToString,
/// SQL `array_get`: list element at a 1-indexed position.
ArrayGet,
/// SQL `array_contains`.
ArrayContains,
// ----
// Column selection
// ----
/// SQL `columns`: selects columns by regex pattern or selector.
Columns,
// ----
// User-defined
// ----
/// A function found in the context's function registry, identified by name.
Udf(String),
}
impl PolarsSQLFunctions {
    /// Names of the built-in SQL functions, in alphabetical order.
    ///
    /// Every entry is unique and corresponds to a name accepted by
    /// `try_from_sql` (feature-gated functions such as `bit_count` and
    /// `initcap` are listed regardless of the enabled features, as before).
    /// User-defined functions are not included: they live in the context's
    /// function registry.
    pub(crate) fn keywords() -> &'static [&'static str] {
        &[
            "abs",
            "acos",
            "acosd",
            "array_agg",
            "array_contains",
            "array_get",
            "array_length",
            "array_lower",
            "array_mean",
            "array_reverse",
            "array_sum",
            "array_to_string",
            "array_unique",
            "array_upper",
            "asin",
            "asind",
            "atan",
            "atan2",
            "atan2d",
            "atand",
            "avg",
            "bit_and",
            "bit_count",
            "bit_length",
            "bit_or",
            "bit_xor",
            "bitand",
            "bitcount",
            "bitor",
            "bitxor",
            "cbrt",
            "ceil",
            "ceiling",
            "char_length",
            "character_length",
            "coalesce",
            "columns",
            "concat",
            "concat_ws",
            "corr",
            "cos",
            "cosd",
            "cot",
            "cotd",
            "count",
            "covar",
            "covar_pop",
            "covar_samp",
            "date",
            "date_part",
            "datetime",
            "degrees",
            "div",
            "ends_with",
            "exp",
            "first",
            "floor",
            "greatest",
            "if",
            "ifnull",
            "initcap",
            "last",
            "least",
            "left",
            "length",
            "ln",
            "log",
            "log10",
            "log1p",
            "log2",
            "lower",
            "ltrim",
            "max",
            "median",
            "min",
            "mod",
            "normalize",
            "nullif",
            "octet_length",
            "pi",
            "pow",
            "power",
            "quantile_cont",
            "quantile_disc",
            "radians",
            "regexp_like",
            "replace",
            "reverse",
            "right",
            "round",
            "rtrim",
            "sign",
            "sin",
            "sind",
            "split_part",
            "sqrt",
            "starts_with",
            "stddev",
            "stddev_samp",
            "stdev",
            "stdev_samp",
            "strftime",
            "string_to_array",
            "strpos",
            "strptime",
            "substr",
            "sum",
            "tan",
            "tand",
            "time",
            "timestamp",
            "unnest",
            "upper",
            "var",
            "var_samp",
            "variance",
            "xor",
        ]
    }
}
impl PolarsSQLFunctions {
    /// Resolve a parsed SQL function to its [`PolarsSQLFunctions`] variant.
    ///
    /// The lookup uses the first identifier of the function name,
    /// lower-cased. Names that are not built in are looked up in the
    /// context's function registry and become `Udf`.
    ///
    /// # Errors
    /// Returns an `SQLInterface` error when the name is neither a built-in
    /// function nor a registered UDF.
    fn try_from_sql(function: &'_ SQLFunction, ctx: &'_ SQLContext) -> PolarsResult<Self> {
        let name = function.name.0[0].value.to_lowercase();
        let resolved = match name.as_str() {
            // ----
            // Bitwise functions
            // ----
            "bit_and" | "bitand" => Self::BitAnd,
            #[cfg(feature = "bitwise")]
            "bit_count" | "bitcount" => Self::BitCount,
            "bit_or" | "bitor" => Self::BitOr,
            "bit_xor" | "bitxor" | "xor" => Self::BitXor,

            // ----
            // Math functions
            // ----
            "abs" => Self::Abs,
            "cbrt" => Self::Cbrt,
            "ceil" | "ceiling" => Self::Ceil,
            "div" => Self::Div,
            "exp" => Self::Exp,
            "floor" => Self::Floor,
            "ln" => Self::Ln,
            "log" => Self::Log,
            "log10" => Self::Log10,
            "log1p" => Self::Log1p,
            "log2" => Self::Log2,
            "mod" => Self::Mod,
            "pi" => Self::Pi,
            "pow" | "power" => Self::Pow,
            "round" => Self::Round,
            "sign" => Self::Sign,
            "sqrt" => Self::Sqrt,

            // ----
            // Trigonometric functions
            // ----
            "cos" => Self::Cos,
            "cot" => Self::Cot,
            "sin" => Self::Sin,
            "tan" => Self::Tan,
            "cosd" => Self::CosD,
            "cotd" => Self::CotD,
            "sind" => Self::SinD,
            "tand" => Self::TanD,
            "acos" => Self::Acos,
            "asin" => Self::Asin,
            "atan" => Self::Atan,
            "atan2" => Self::Atan2,
            "acosd" => Self::AcosD,
            "asind" => Self::AsinD,
            "atand" => Self::AtanD,
            "atan2d" => Self::Atan2D,
            "degrees" => Self::Degrees,
            "radians" => Self::Radians,

            // ----
            // Conditional functions
            // ----
            "coalesce" => Self::Coalesce,
            "greatest" => Self::Greatest,
            "if" => Self::If,
            "ifnull" => Self::IfNull,
            "least" => Self::Least,
            "nullif" => Self::NullIf,

            // ----
            // Temporal functions
            // ----
            "date_part" => Self::DatePart,
            "strftime" => Self::Strftime,

            // ----
            // String functions
            // ----
            "bit_length" => Self::BitLength,
            "concat" => Self::Concat,
            "concat_ws" => Self::ConcatWS,
            "date" => Self::Date,
            "timestamp" | "datetime" => Self::Timestamp,
            "ends_with" => Self::EndsWith,
            #[cfg(feature = "nightly")]
            "initcap" => Self::InitCap,
            "length" | "char_length" | "character_length" => Self::Length,
            "left" => Self::Left,
            "lower" => Self::Lower,
            "ltrim" => Self::LTrim,
            "normalize" => Self::Normalize,
            "octet_length" => Self::OctetLength,
            "strpos" => Self::StrPos,
            "regexp_like" => Self::RegexpLike,
            "replace" => Self::Replace,
            "reverse" => Self::Reverse,
            "right" => Self::Right,
            "rtrim" => Self::RTrim,
            "split_part" => Self::SplitPart,
            "starts_with" => Self::StartsWith,
            "string_to_array" => Self::StringToArray,
            "strptime" => Self::Strptime,
            "substr" => Self::Substring,
            "time" => Self::Time,
            "upper" => Self::Upper,

            // ----
            // Aggregate functions
            // ----
            "avg" => Self::Avg,
            "corr" => Self::Corr,
            "count" => Self::Count,
            "covar_pop" => Self::CovarPop,
            "covar" | "covar_samp" => Self::CovarSamp,
            "first" => Self::First,
            "last" => Self::Last,
            "max" => Self::Max,
            "median" => Self::Median,
            "quantile_cont" => Self::QuantileCont,
            "quantile_disc" => Self::QuantileDisc,
            "min" => Self::Min,
            "stdev" | "stddev" | "stdev_samp" | "stddev_samp" => Self::StdDev,
            "sum" => Self::Sum,
            "var" | "variance" | "var_samp" => Self::Variance,

            // ----
            // Array functions
            // ----
            "array_agg" => Self::ArrayAgg,
            "array_contains" => Self::ArrayContains,
            "array_get" => Self::ArrayGet,
            "array_length" => Self::ArrayLength,
            "array_lower" => Self::ArrayMin,
            "array_mean" => Self::ArrayMean,
            "array_reverse" => Self::ArrayReverse,
            "array_sum" => Self::ArraySum,
            "array_to_string" => Self::ArrayToString,
            "array_unique" => Self::ArrayUnique,
            "array_upper" => Self::ArrayMax,
            "unnest" => Self::Explode,

            // ----
            // Column selection
            // ----
            "columns" => Self::Columns,

            // ----
            // Anything else must be a registered user-defined function
            // ----
            udf if ctx.function_registry.contains(udf) => Self::Udf(udf.to_string()),
            unknown => polars_bail!(SQLInterface: "unsupported function '{}'", unknown),
        };
        Ok(resolved)
    }
}
impl SQLFunctionVisitor<'_> {
pub(crate) fn visit_function(&mut self) -> PolarsResult<Expr> {
use PolarsSQLFunctions::*;
use polars_lazy::prelude::Literal;
let function_name = PolarsSQLFunctions::try_from_sql(self.func, self.ctx)?;
let function = self.func;
if !function.within_group.is_empty() {
polars_bail!(SQLInterface: "'WITHIN GROUP' is not currently supported")
}
if function.filter.is_some() {
polars_bail!(SQLInterface: "'FILTER' is not currently supported")
}
if function.null_treatment.is_some() {
polars_bail!(SQLInterface: "'IGNORE|RESPECT NULLS' is not currently supported")
}
let log_with_base =
|e: Expr, base: f64| e.log(LiteralValue::Dyn(DynLiteralValue::Float(base)).lit());
match function_name {
BitAnd => self.visit_binary::<Expr>(Expr::and),
#[cfg(feature = "bitwise")]
BitCount => self.visit_unary(Expr::bitwise_count_ones),
BitOr => self.visit_binary::<Expr>(Expr::or),
BitXor => self.visit_binary::<Expr>(Expr::xor),
Abs => self.visit_unary(Expr::abs),
Cbrt => self.visit_unary(Expr::cbrt),
Ceil => self.visit_unary(Expr::ceil),
Div => self.visit_binary(|e, d| e.floor_div(d).cast(DataType::Int64)),
Exp => self.visit_unary(Expr::exp),
Floor => self.visit_unary(Expr::floor),
Ln => self.visit_unary(|e| log_with_base(e, std::f64::consts::E)),
Log => self.visit_binary(Expr::log),
Log10 => self.visit_unary(|e| log_with_base(e, 10.0)),
Log1p => self.visit_unary(Expr::log1p),
Log2 => self.visit_unary(|e| log_with_base(e, 2.0)),
Pi => self.visit_nullary(Expr::pi),
Mod => self.visit_binary(|e1, e2| e1 % e2),
Pow => self.visit_binary::<Expr>(Expr::pow),
Round => {
let args = extract_args(function)?;
match args.len() {
1 => self.visit_unary(|e| e.round(0, RoundMode::default())),
2 => self.try_visit_binary(|e, decimals| {
Ok(e.round(match decimals {
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
if n >= 0 { n as u32 } else {
polars_bail!(SQLInterface: "ROUND does not currently support negative decimals value ({})", args[1])
}
},
_ => polars_bail!(SQLSyntax: "invalid value for ROUND decimals ({})", args[1]),
}, RoundMode::default()))
}),
_ => polars_bail!(SQLSyntax: "ROUND expects 1-2 arguments (found {})", args.len()),
}
},
Sign => self.visit_unary(Expr::sign),
Sqrt => self.visit_unary(Expr::sqrt),
Acos => self.visit_unary(Expr::arccos),
AcosD => self.visit_unary(|e| e.arccos().degrees()),
Asin => self.visit_unary(Expr::arcsin),
AsinD => self.visit_unary(|e| e.arcsin().degrees()),
Atan => self.visit_unary(Expr::arctan),
Atan2 => self.visit_binary(Expr::arctan2),
Atan2D => self.visit_binary(|e, s| e.arctan2(s).degrees()),
AtanD => self.visit_unary(|e| e.arctan().degrees()),
Cos => self.visit_unary(Expr::cos),
CosD => self.visit_unary(|e| e.radians().cos()),
Cot => self.visit_unary(Expr::cot),
CotD => self.visit_unary(|e| e.radians().cot()),
Degrees => self.visit_unary(Expr::degrees),
Radians => self.visit_unary(Expr::radians),
Sin => self.visit_unary(Expr::sin),
SinD => self.visit_unary(|e| e.radians().sin()),
Tan => self.visit_unary(Expr::tan),
TanD => self.visit_unary(|e| e.radians().tan()),
Coalesce => self.visit_variadic(coalesce),
Greatest => self.visit_variadic(|exprs: &[Expr]| max_horizontal(exprs).unwrap()),
If => {
let args = extract_args(function)?;
match args.len() {
3 => self.try_visit_ternary(|cond: Expr, expr1: Expr, expr2: Expr| {
Ok(when(cond).then(expr1).otherwise(expr2))
}),
_ => {
polars_bail!(SQLSyntax: "IF expects 3 arguments (found {})", args.len()
)
},
}
},
IfNull => {
let args = extract_args(function)?;
match args.len() {
2 => self.visit_variadic(coalesce),
_ => {
polars_bail!(SQLSyntax: "IFNULL expects 2 arguments (found {})", args.len())
},
}
},
Least => self.visit_variadic(|exprs: &[Expr]| min_horizontal(exprs).unwrap()),
NullIf => {
let args = extract_args(function)?;
match args.len() {
2 => self.visit_binary(|l: Expr, r: Expr| {
when(l.clone().eq(r))
.then(lit(LiteralValue::untyped_null()))
.otherwise(l)
}),
_ => {
polars_bail!(SQLSyntax: "NULLIF expects 2 arguments (found {})", args.len())
},
}
},
DatePart => self.try_visit_binary(|part, e| {
match part {
Expr::Literal(p) if p.extract_str().is_some() => {
let p = p.extract_str().unwrap();
parse_extract_date_part(
e,
&DateTimeField::Custom(Ident {
value: p.to_string(),
quote_style: None,
span: Span::empty(),
}),
)
},
_ => {
polars_bail!(SQLSyntax: "invalid 'part' for EXTRACT/DATE_PART ({})", part);
},
}
}),
Strftime => {
let args = extract_args(function)?;
match args.len() {
2 => self.visit_binary(|e, fmt: String| e.dt().strftime(fmt.as_str())),
_ => {
polars_bail!(SQLSyntax: "STRFTIME expects 2 arguments (found {})", args.len())
},
}
},
BitLength => self.visit_unary(|e| e.str().len_bytes() * lit(8)),
Concat => {
let args = extract_args(function)?;
if args.is_empty() {
polars_bail!(SQLSyntax: "CONCAT expects at least 1 argument (found 0)");
} else {
self.visit_variadic(|exprs: &[Expr]| concat_str(exprs, "", true))
}
},
ConcatWS => {
let args = extract_args(function)?;
if args.len() < 2 {
polars_bail!(SQLSyntax: "CONCAT_WS expects at least 2 arguments (found {})", args.len());
} else {
self.try_visit_variadic(|exprs: &[Expr]| {
match &exprs[0] {
Expr::Literal(lv) if lv.extract_str().is_some() => Ok(concat_str(&exprs[1..], lv.extract_str().unwrap(), true)),
_ => polars_bail!(SQLSyntax: "CONCAT_WS 'separator' must be a literal string (found {:?})", exprs[0]),
}
})
}
},
Date => {
let args = extract_args(function)?;
match args.len() {
1 => self.visit_unary(|e| e.str().to_date(StrptimeOptions::default())),
2 => self.visit_binary(|e, fmt| e.str().to_date(fmt)),
_ => {
polars_bail!(SQLSyntax: "DATE expects 1-2 arguments (found {})", args.len())
},
}
},
EndsWith => self.visit_binary(|e, s| e.str().ends_with(s)),
#[cfg(feature = "nightly")]
InitCap => self.visit_unary(|e| e.str().to_titlecase()),
Left => self.try_visit_binary(|e, length| {
Ok(match length {
Expr::Literal(lv) if lv.is_null() => lit(lv),
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => lit(""),
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
let len = if n > 0 {
lit(n)
} else {
(e.clone().str().len_chars() + lit(n)).clip_min(lit(0))
};
e.str().slice(lit(0), len)
},
Expr::Literal(v) => {
polars_bail!(SQLSyntax: "invalid 'n_chars' for LEFT ({:?})", v)
},
_ => when(length.clone().gt_eq(lit(0)))
.then(e.clone().str().slice(lit(0), length.clone().abs()))
.otherwise(e.clone().str().slice(
lit(0),
(e.str().len_chars() + length.clone()).clip_min(lit(0)),
)),
})
}),
Length => self.visit_unary(|e| e.str().len_chars()),
Lower => self.visit_unary(|e| e.str().to_lowercase()),
LTrim => {
let args = extract_args(function)?;
match args.len() {
1 => self.visit_unary(|e| {
e.str().strip_chars_start(lit(LiteralValue::untyped_null()))
}),
2 => self.visit_binary(|e, s| e.str().strip_chars_start(s)),
_ => {
polars_bail!(SQLSyntax: "LTRIM expects 1-2 arguments (found {})", args.len())
},
}
},
Normalize => {
let args = extract_args(function)?;
match args.len() {
1 => self.visit_unary(|e| e.str().normalize(UnicodeForm::NFC)),
2 => {
let form = if let FunctionArgExpr::Expr(SQLExpr::Identifier(Ident {
value: s,
quote_style: None,
span: _,
})) = args[1]
{
match s.to_uppercase().as_str() {
"NFC" => UnicodeForm::NFC,
"NFD" => UnicodeForm::NFD,
"NFKC" => UnicodeForm::NFKC,
"NFKD" => UnicodeForm::NFKD,
_ => {
polars_bail!(SQLSyntax: "invalid 'form' for NORMALIZE (found {})", s)
},
}
} else {
polars_bail!(SQLSyntax: "invalid 'form' for NORMALIZE (found {})", args[1])
};
self.try_visit_binary(|e, _form: Expr| Ok(e.str().normalize(form.clone())))
},
_ => {
polars_bail!(SQLSyntax: "NORMALIZE expects 1-2 arguments (found {})", args.len())
},
}
},
OctetLength => self.visit_unary(|e| e.str().len_bytes()),
StrPos => {
self.visit_binary(|expr, substring| {
(expr.str().find(substring, true) + typed_lit(1u32)).fill_null(typed_lit(0u32))
})
},
RegexpLike => {
let args = extract_args(function)?;
match args.len() {
2 => self.visit_binary(|e, s| e.str().contains(s, true)),
3 => self.try_visit_ternary(|e, pat, flags| {
Ok(e.str().contains(
match (pat, flags) {
(Expr::Literal(s_lv), Expr::Literal(f_lv)) if s_lv.extract_str().is_some() && f_lv.extract_str().is_some() => {
let s = s_lv.extract_str().unwrap();
let f = f_lv.extract_str().unwrap();
if f.is_empty() {
polars_bail!(SQLSyntax: "invalid/empty 'flags' for REGEXP_LIKE ({})", args[2]);
};
lit(format!("(?{f}){s}"))
},
_ => {
polars_bail!(SQLSyntax: "invalid arguments for REGEXP_LIKE ({}, {})", args[1], args[2]);
},
},
true))
}),
_ => polars_bail!(SQLSyntax: "REGEXP_LIKE expects 2-3 arguments (found {})",args.len()),
}
},
Replace => {
let args = extract_args(function)?;
match args.len() {
3 => self
.try_visit_ternary(|e, old, new| Ok(e.str().replace_all(old, new, true))),
_ => {
polars_bail!(SQLSyntax: "REPLACE expects 3 arguments (found {})", args.len())
},
}
},
Reverse => self.visit_unary(|e| e.str().reverse()),
Right => self.try_visit_binary(|e, length| {
Ok(match length {
Expr::Literal(lv) if lv.is_null() => lit(lv),
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => typed_lit(""),
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
let n: i64 = n.try_into().unwrap();
let offset = if n < 0 {
lit(n.abs())
} else {
e.clone().str().len_chars().cast(DataType::Int32) - lit(n)
};
e.str().slice(offset, lit(LiteralValue::untyped_null()))
},
Expr::Literal(v) => {
polars_bail!(SQLSyntax: "invalid 'n_chars' for RIGHT ({:?})", v)
},
_ => when(length.clone().lt(lit(0)))
.then(
e.clone()
.str()
.slice(length.clone().abs(), lit(LiteralValue::untyped_null())),
)
.otherwise(e.clone().str().slice(
e.str().len_chars().cast(DataType::Int32) - length.clone(),
lit(LiteralValue::untyped_null()),
)),
})
}),
RTrim => {
let args = extract_args(function)?;
match args.len() {
1 => self.visit_unary(|e| {
e.str().strip_chars_end(lit(LiteralValue::untyped_null()))
}),
2 => self.visit_binary(|e, s| e.str().strip_chars_end(s)),
_ => {
polars_bail!(SQLSyntax: "RTRIM expects 1-2 arguments (found {})", args.len())
},
}
},
SplitPart => {
let args = extract_args(function)?;
match args.len() {
3 => self.try_visit_ternary(|e, sep, idx| {
let idx = adjust_one_indexed_param(idx, true);
Ok(when(e.clone().is_not_null())
.then(
e.clone()
.str()
.split(sep)
.list()
.get(idx, true)
.fill_null(lit("")),
)
.otherwise(e))
}),
_ => {
polars_bail!(SQLSyntax: "SPLIT_PART expects 3 arguments (found {})", args.len())
},
}
},
StartsWith => self.visit_binary(|e, s| e.str().starts_with(s)),
StringToArray => {
let args = extract_args(function)?;
match args.len() {
2 => self.visit_binary(|e, sep| e.str().split(sep)),
_ => {
polars_bail!(SQLSyntax: "STRING_TO_ARRAY expects 2 arguments (found {})", args.len())
},
}
},
Strptime => {
let args = extract_args(function)?;
match args.len() {
2 => self.visit_binary(|e, fmt: String| {
e.str().strptime(
DataType::Datetime(TimeUnit::Microseconds, None),
StrptimeOptions {
format: Some(fmt.into()),
..Default::default()
},
lit("latest"),
)
}),
_ => {
polars_bail!(SQLSyntax: "STRPTIME expects 2 arguments (found {})", args.len())
},
}
},
Time => {
let args = extract_args(function)?;
match args.len() {
1 => self.visit_unary(|e| e.str().to_time(StrptimeOptions::default())),
2 => self.visit_binary(|e, fmt| e.str().to_time(fmt)),
_ => {
polars_bail!(SQLSyntax: "TIME expects 1-2 arguments (found {})", args.len())
},
}
},
Timestamp => {
let args = extract_args(function)?;
match args.len() {
1 => self.visit_unary(|e| {
e.str()
.to_datetime(None, None, StrptimeOptions::default(), lit("latest"))
}),
2 => self
.visit_binary(|e, fmt| e.str().to_datetime(None, None, fmt, lit("latest"))),
_ => {
polars_bail!(SQLSyntax: "DATETIME expects 1-2 arguments (found {})", args.len())
},
}
},
Substring => {
let args = extract_args(function)?;
match args.len() {
2 => self.try_visit_binary(|e, start| {
Ok(match start {
Expr::Literal(lv) if lv.is_null() => lit(lv),
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n <= 0 => e,
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => e.str().slice(lit(n - 1), lit(LiteralValue::untyped_null())),
Expr::Literal(_) => polars_bail!(SQLSyntax: "invalid 'start' for SUBSTR ({})", args[1]),
_ => start.clone() + lit(1),
})
}),
3 => self.try_visit_ternary(|e: Expr, start: Expr, length: Expr| {
Ok(match (start.clone(), length.clone()) {
(Expr::Literal(lv), _) | (_, Expr::Literal(lv)) if lv.is_null() => lit(lv),
(_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) if n < 0 => {
polars_bail!(SQLSyntax: "SUBSTR does not support negative length ({})", args[2])
},
(Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) if n > 0 => e.str().slice(lit(n - 1), length),
(Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) => {
e.str().slice(lit(0), (length + lit(n - 1)).clip_min(lit(0)))
},
(Expr::Literal(_), _) => polars_bail!(SQLSyntax: "invalid 'start' for SUBSTR ({})", args[1]),
(_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(_)))) => {
polars_bail!(SQLSyntax: "invalid 'length' for SUBSTR ({})", args[1])
},
_ => {
let adjusted_start = start - lit(1);
when(adjusted_start.clone().lt(lit(0)))
.then(e.clone().str().slice(lit(0), (length.clone() + adjusted_start.clone()).clip_min(lit(0))))
.otherwise(e.str().slice(adjusted_start, length))
}
})
}),
_ => polars_bail!(SQLSyntax: "SUBSTR expects 2-3 arguments (found {})", args.len()),
}
},
Upper => self.visit_unary(|e| e.str().to_uppercase()),
Avg => self.visit_unary(Expr::mean),
Corr => self.visit_binary(polars_lazy::dsl::pearson_corr),
Count => self.visit_count(),
CovarPop => self.visit_binary(|a, b| polars_lazy::dsl::cov(a, b, 0)),
CovarSamp => self.visit_binary(|a, b| polars_lazy::dsl::cov(a, b, 1)),
First => self.visit_unary(Expr::first),
Last => self.visit_unary(Expr::last),
Max => self.visit_unary_with_opt_cumulative(Expr::max, Expr::cum_max),
Median => self.visit_unary(Expr::median),
QuantileCont => {
let args = extract_args(function)?;
match args.len() {
2 => self.try_visit_binary(|e, q| {
let value = match q {
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(f))) => {
if (0.0..=1.0).contains(&f) {
Expr::from(f)
} else {
polars_bail!(SQLSyntax: "QUANTILE_CONT value must be between 0 and 1 ({})", args[1])
}
},
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
if (0..=1).contains(&n) {
Expr::from(n as f64)
} else {
polars_bail!(SQLSyntax: "QUANTILE_CONT value must be between 0 and 1 ({})", args[1])
}
},
_ => polars_bail!(SQLSyntax: "invalid value for QUANTILE_CONT ({})", args[1])
};
Ok(e.quantile(value, QuantileMethod::Linear))
}),
_ => polars_bail!(SQLSyntax: "QUANTILE_CONT expects 2 arguments (found {})", args.len()),
}
},
QuantileDisc => {
let args = extract_args(function)?;
match args.len() {
2 => self.try_visit_binary(|e, q| {
let value = match q {
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(f))) => {
if (0.0..=1.0).contains(&f) {
Expr::from(f)
} else {
polars_bail!(SQLSyntax: "QUANTILE_DISC value must be between 0 and 1 ({})", args[1])
}
},
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
if (0..=1).contains(&n) {
Expr::from(n as f64)
} else {
polars_bail!(SQLSyntax: "QUANTILE_DISC value must be between 0 and 1 ({})", args[1])
}
},
_ => polars_bail!(SQLSyntax: "invalid value for QUANTILE_DISC ({})", args[1])
};
Ok(e.quantile(value, QuantileMethod::Equiprobable))
}),
_ => polars_bail!(SQLSyntax: "QUANTILE_DISC expects 2 arguments (found {})", args.len()),
}
},
Min => self.visit_unary_with_opt_cumulative(Expr::min, Expr::cum_min),
StdDev => self.visit_unary(|e| e.std(1)),
Sum => self.visit_unary_with_opt_cumulative(Expr::sum, Expr::cum_sum),
Variance => self.visit_unary(|e| e.var(1)),
ArrayAgg => self.visit_arr_agg(),
ArrayContains => self.visit_binary::<Expr>(|e, s| e.list().contains(s, true)),
ArrayGet => {
self.visit_binary(|e, idx: Expr| {
let idx = adjust_one_indexed_param(idx, true);
e.list().get(idx, true)
})
},
ArrayLength => self.visit_unary(|e| e.list().len()),
ArrayMax => self.visit_unary(|e| e.list().max()),
ArrayMean => self.visit_unary(|e| e.list().mean()),
ArrayMin => self.visit_unary(|e| e.list().min()),
ArrayReverse => self.visit_unary(|e| e.list().reverse()),
ArraySum => self.visit_unary(|e| e.list().sum()),
ArrayToString => self.visit_arr_to_string(),
ArrayUnique => self.visit_unary(|e| e.list().unique()),
Explode => self.visit_unary(|e| e.explode()),
Columns => {
let active_schema = self.active_schema;
self.try_visit_unary(|e: Expr| match e {
Expr::Literal(lv) if lv.extract_str().is_some() => {
let pat = lv.extract_str().unwrap();
if pat == "*" {
polars_bail!(
SQLSyntax: "COLUMNS('*') is not a valid regex; \
did you mean COLUMNS(*)?"
)
};
let pat = match pat {
_ if pat.starts_with('^') && pat.ends_with('$') => pat.to_string(),
_ if pat.starts_with('^') => format!("{pat}.*$"),
_ if pat.ends_with('$') => format!("^.*{pat}"),
_ => format!("^.*{pat}.*$"),
};
if let Some(active_schema) = &active_schema {
let rx = polars_utils::regex_cache::compile_regex(&pat).unwrap();
let col_names = active_schema
.iter_names()
.filter(|name| rx.is_match(name))
.cloned()
.collect::<Vec<_>>();
Ok(if col_names.len() == 1 {
col(col_names.into_iter().next().unwrap())
} else {
cols(col_names).as_expr()
})
} else {
Ok(col(pat.as_str()))
}
},
Expr::Selector(s) => Ok(s.as_expr()),
_ => polars_bail!(SQLSyntax: "COLUMNS expects a regex; found {:?}", e),
})
},
Udf(func_name) => self.visit_udf(&func_name),
}
}
/// Call the registered user-defined function `func_name` with the parsed
/// arguments of the current SQL function.
///
/// Only plain expression arguments are accepted; anything else, or a name
/// missing from the registry, yields an `SQLInterface` error.
fn visit_udf(&mut self, func_name: &str) -> PolarsResult<Expr> {
    let mut udf_args = Vec::new();
    for arg in extract_args(self.func)? {
        let FunctionArgExpr::Expr(sql_expr) = arg else {
            polars_bail!(SQLInterface: "only expressions are supported in UDFs")
        };
        udf_args.push(parse_sql_expr(sql_expr, self.ctx, self.active_schema)?);
    }

    let udf = self
        .ctx
        .function_registry
        .get_udf(func_name)?
        .ok_or_else(|| polars_err!(SQLInterface: "UDF {} not found", func_name))?;
    Ok(udf.call(udf_args))
}
/// Build either the plain aggregate `f` or its cumulative counterpart
/// `cumulative_f`, depending on the window specification.
///
/// The cumulative form is used only when the window has an `ORDER BY` and no
/// `PARTITION BY`: the argument is sorted by the order expressions and then
/// passed to `cumulative_f(.., false)`. Every other window falls back to
/// `visit_unary(f)`.
fn apply_cumulative_window(
    &mut self,
    f: impl Fn(Expr) -> Expr,
    cumulative_f: impl Fn(Expr, bool) -> Expr,
    WindowSpec {
        partition_by,
        order_by,
        ..
    }: &WindowSpec,
) -> PolarsResult<Expr> {
    // not an "ORDER BY only" window: regular aggregate handling
    if order_by.is_empty() || !partition_by.is_empty() {
        return self.visit_unary(f);
    }

    let mut sort_exprs: Vec<Expr> = Vec::with_capacity(order_by.len());
    let mut descending: Vec<bool> = Vec::with_capacity(order_by.len());
    for ob in order_by {
        sort_exprs.push(parse_sql_expr(&ob.expr, self.ctx, self.active_schema)?);
        // only an explicit DESC (asc == Some(false)) sorts descending
        descending.push(matches!(ob.asc, Some(false)));
    }

    self.visit_unary_no_window(|e| {
        let sort_options =
            SortMultipleOptions::default().with_order_descending_multi(descending.clone());
        cumulative_f(e.sort_by(&sort_exprs, sort_options), false)
    })
}
/// Infallible convenience wrapper around `try_visit_unary`.
fn visit_unary(&mut self, f: impl Fn(Expr) -> Expr) -> PolarsResult<Expr> {
    self.try_visit_unary(|expr| Ok(f(expr)))
}
/// Parse the single argument of the function (an expression or `*`), map it
/// through `f`, then apply the function's `OVER` clause, if any.
///
/// Any other argument shape produces the "no function matches" error.
fn try_visit_unary(&mut self, f: impl Fn(Expr) -> PolarsResult<Expr>) -> PolarsResult<Expr> {
    let args = extract_args(self.func)?;
    // a bare `*` argument is parsed as a wildcard expression
    let wildcard = SQLExpr::Wildcard(AttachedToken::empty());
    let sql_expr = match args.as_slice() {
        [FunctionArgExpr::Expr(sql_expr)] => sql_expr,
        [FunctionArgExpr::Wildcard] => &wildcard,
        _ => return self.not_supported_error(),
    };
    let mapped = f(parse_sql_expr(sql_expr, self.ctx, self.active_schema)?)?;
    self.apply_window_spec(mapped, &self.func.over)
}
/// Visit a single-argument aggregate that has a cumulative counterpart
/// (e.g. `SUM` / `cum_sum`).
///
/// Without an `OVER` clause this is `visit_unary(f)`; with a window
/// specification the choice is delegated to `apply_cumulative_window`.
/// Named windows are rejected.
fn visit_unary_with_opt_cumulative(
    &mut self,
    f: impl Fn(Expr) -> Expr,
    cumulative_f: impl Fn(Expr, bool) -> Expr,
) -> PolarsResult<Expr> {
    let Some(window) = self.func.over.as_ref() else {
        return self.visit_unary(f);
    };
    match window {
        WindowType::WindowSpec(spec) => self.apply_cumulative_window(f, cumulative_f, spec),
        WindowType::NamedWindow(named_window) => polars_bail!(
            SQLInterface: "Named windows are not currently supported; found {:?}",
            named_window
        ),
    }
}
/// Parse the single expression argument and map it through `f`, without
/// applying any `OVER` clause (the caller handles the window itself).
fn visit_unary_no_window(&mut self, f: impl Fn(Expr) -> Expr) -> PolarsResult<Expr> {
    let args = extract_args(self.func)?;
    if let [FunctionArgExpr::Expr(sql_expr)] = args.as_slice() {
        parse_sql_expr(sql_expr, self.ctx, self.active_schema).map(f)
    } else {
        self.not_supported_error()
    }
}
/// Infallible convenience wrapper around `try_visit_binary`.
fn visit_binary<Arg: FromSQLExpr>(
    &mut self,
    f: impl Fn(Expr, Arg) -> Expr,
) -> PolarsResult<Expr> {
    self.try_visit_binary(|lhs, rhs| Ok(f(lhs, rhs)))
}
/// Visit a function that takes exactly two expression arguments.
///
/// The first argument is parsed as an expression (using the active schema);
/// the second is converted through `Arg::from_sql_expr`. Any other argument
/// shape produces the "no function matches" error.
fn try_visit_binary<Arg: FromSQLExpr>(
    &mut self,
    f: impl Fn(Expr, Arg) -> PolarsResult<Expr>,
) -> PolarsResult<Expr> {
    let args = extract_args(self.func)?;
    let [FunctionArgExpr::Expr(first), FunctionArgExpr::Expr(second)] = args.as_slice() else {
        return self.not_supported_error();
    };
    let lhs = parse_sql_expr(first, self.ctx, self.active_schema)?;
    let rhs = Arg::from_sql_expr(second, self.ctx)?;
    f(lhs, rhs)
}
/// Infallible convenience wrapper around `try_visit_variadic`.
fn visit_variadic(&mut self, f: impl Fn(&[Expr]) -> Expr) -> PolarsResult<Expr> {
    self.try_visit_variadic(|exprs| Ok(f(exprs)))
}
/// Visit a function taking any number of expression arguments.
///
/// Arguments are parsed in order; the first non-expression argument (such
/// as a wildcard) produces the "no function matches" error.
fn try_visit_variadic(
    &mut self,
    f: impl Fn(&[Expr]) -> PolarsResult<Expr>,
) -> PolarsResult<Expr> {
    let args = extract_args(self.func)?;
    let mut parsed = Vec::with_capacity(args.len());
    for arg in args {
        let FunctionArgExpr::Expr(sql_expr) = arg else {
            return self.not_supported_error();
        };
        parsed.push(parse_sql_expr(sql_expr, self.ctx, self.active_schema)?);
    }
    f(&parsed)
}
/// Visit a function that takes exactly three expression arguments.
///
/// The first argument is parsed as an expression (using the active schema);
/// the second and third are converted through `Arg::from_sql_expr`. Any
/// other argument shape produces the "no function matches" error.
fn try_visit_ternary<Arg: FromSQLExpr>(
    &mut self,
    f: impl Fn(Expr, Arg, Arg) -> PolarsResult<Expr>,
) -> PolarsResult<Expr> {
    let args = extract_args(self.func)?;
    let [
        FunctionArgExpr::Expr(first),
        FunctionArgExpr::Expr(second),
        FunctionArgExpr::Expr(third),
    ] = args.as_slice()
    else {
        return self.not_supported_error();
    };
    let target = parse_sql_expr(first, self.ctx, self.active_schema)?;
    let arg_a = Arg::from_sql_expr(second, self.ctx)?;
    let arg_b = Arg::from_sql_expr(third, self.ctx)?;
    f(target, arg_a, arg_b)
}
/// Visit a function that takes no arguments; passing any argument produces
/// the "no function matches" error.
fn visit_nullary(&self, f: impl Fn() -> Expr) -> PolarsResult<Expr> {
    if extract_args(self.func)?.is_empty() {
        Ok(f())
    } else {
        self.not_supported_error()
    }
}
/// Translate `ARRAY_AGG(...)`: a single expression, optionally made
/// `DISTINCT` (stable unique), with `ORDER BY` / `LIMIT` clauses applied in
/// the order given, and finally imploded into a list.
///
/// Other clauses are ignored. `LIMIT` must be a non-negative integer literal.
fn visit_arr_agg(&mut self) -> PolarsResult<Expr> {
    let (args, is_distinct, clauses) = extract_args_and_clauses(self.func)?;
    let [FunctionArgExpr::Expr(sql_expr)] = args.as_slice() else {
        polars_bail!(SQLSyntax: "ARRAY_AGG must have exactly one argument; found {}", args.len())
    };

    let parsed = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
    let mut agg = if is_distinct {
        parsed.unique_stable()
    } else {
        parsed
    };

    for clause in &clauses {
        agg = match clause {
            FunctionArgumentClause::OrderBy(order_exprs) => {
                self.apply_order_by(agg, order_exprs)?
            },
            FunctionArgumentClause::Limit(limit_expr) => {
                match parse_sql_expr(limit_expr, self.ctx, self.active_schema)? {
                    Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n >= 0 => {
                        agg.head(Some(n as usize))
                    },
                    _ => {
                        polars_bail!(SQLSyntax: "LIMIT in ARRAY_AGG must be a positive integer")
                    },
                }
            },
            _ => agg,
        };
    }
    Ok(agg.implode())
}
/// Translate `ARRAY_TO_STRING(arr, sep [, null_value])`.
///
/// The list is cast to a list of strings and joined with `sep`. With the
/// `list_eval` feature a third, literal-string argument is accepted: when it
/// is non-empty, null elements are replaced by it before joining.
fn visit_arr_to_string(&mut self) -> PolarsResult<Expr> {
    let args = extract_args(self.func)?;
    let as_string_list = |e: Expr| e.cast(DataType::List(Box::from(DataType::String)));
    match args.len() {
        2 => self.try_visit_binary(|e, sep| Ok(as_string_list(e).list().join(sep, true))),
        #[cfg(feature = "list_eval")]
        3 => self.try_visit_ternary(|e, sep, null_value: Expr| {
            let replacement = match &null_value {
                Expr::Literal(lv) => lv.extract_str(),
                _ => None,
            };
            let Some(replacement) = replacement else {
                polars_bail!(SQLSyntax: "invalid null value for ARRAY_TO_STRING ({})", args[2])
            };
            let strings = as_string_list(e);
            Ok(if replacement.is_empty() {
                strings.list().join(sep, true)
            } else {
                strings
                    .list()
                    .eval(col("").fill_null(lit(replacement)))
                    .list()
                    .join(sep, false)
            })
        }),
        _ => {
            polars_bail!(SQLSyntax: "ARRAY_TO_STRING expects 2-3 arguments (found {})", args.len())
        },
    }
}
/// Translate `COUNT(...)`, then apply the function's `OVER` clause.
///
/// * `COUNT(*)` / `COUNT()` -> row count (`len()`)
/// * `COUNT(expr)` -> `expr.count()`
/// * `COUNT(DISTINCT expr)` -> `n_unique()` minus one when nulls are present
///
/// Any other form produces the "no function matches" error.
fn visit_count(&mut self) -> PolarsResult<Expr> {
    let (args, is_distinct) = extract_args_distinct(self.func)?;
    let count_expr = match (is_distinct, args.as_slice()) {
        (false, [] | [FunctionArgExpr::Wildcard]) => len(),
        (distinct, [FunctionArgExpr::Expr(sql_expr)]) => {
            let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
            if distinct {
                // n_unique counts null as a value, so subtract it when present
                expr.clone().n_unique().sub(expr.null_count().gt(lit(0)))
            } else {
                expr.count()
            }
        },
        _ => return self.not_supported_error(),
    };
    self.apply_window_spec(count_expr, &self.func.over)
}
/// Sort `expr` by the given SQL `ORDER BY` expressions (stable sort).
///
/// Direction defaults to ascending. When `NULLS FIRST/LAST` is not given,
/// nulls go last for ascending order and first for descending order.
fn apply_order_by(&mut self, expr: Expr, order_by: &[OrderByExpr]) -> PolarsResult<Expr> {
    let by = order_by
        .iter()
        .map(|ob| parse_sql_expr(&ob.expr, self.ctx, self.active_schema))
        .collect::<PolarsResult<Vec<_>>>()?;
    let descending: Vec<bool> = order_by.iter().map(|ob| ob.asc == Some(false)).collect();
    let nulls_last: Vec<bool> = order_by
        .iter()
        .zip(&descending)
        .map(|(ob, &desc_order)| !ob.nulls_first.unwrap_or(desc_order))
        .collect();

    let sort_options = SortMultipleOptions::default()
        .with_order_descending_multi(descending)
        .with_nulls_last_multi(nulls_last)
        .with_maintain_order(true);
    Ok(expr.sort_by(by, sort_options))
}
/// Apply an optional SQL `OVER (...)` clause to `expr`.
///
/// * no window -> `expr` is returned unchanged
/// * `PARTITION BY` present -> `expr.over(partition expressions)`
/// * otherwise -> `expr.over(order expressions)`, where an order expression
///   with an explicit direction is sorted accordingly
///
/// Named windows are rejected with an `SQLInterface` error.
fn apply_window_spec(
    &mut self,
    expr: Expr,
    window_type: &Option<WindowType>,
) -> PolarsResult<Expr> {
    let window_spec = match window_type {
        None => return Ok(expr),
        Some(WindowType::NamedWindow(named_window)) => polars_bail!(
            SQLInterface: "Named windows are not currently supported; found {:?}",
            named_window
        ),
        Some(WindowType::WindowSpec(spec)) => spec,
    };

    let over_exprs = if window_spec.partition_by.is_empty() {
        let mut ordered = Vec::with_capacity(window_spec.order_by.len());
        for ob in &window_spec.order_by {
            let e = parse_sql_expr(&ob.expr, self.ctx, self.active_schema)?;
            ordered.push(match ob.asc {
                Some(asc) => e.sort(SortOptions::default().with_order_descending(!asc)),
                None => e,
            });
        }
        ordered
    } else {
        let mut partitions = Vec::with_capacity(window_spec.partition_by.len());
        for p in &window_spec.partition_by {
            partitions.push(parse_sql_expr(p, self.ctx, self.active_schema)?);
        }
        partitions
    };
    Ok(expr.over(over_exprs))
}
/// Always returns the `SQLInterface` error reported when the arguments of
/// the current function call match no supported signature.
fn not_supported_error(&self) -> PolarsResult<Expr> {
    let call = self.func.to_string();
    polars_bail!(
        SQLInterface:
        "no function matches the given name and arguments: `{}`",
        call
    );
}
}
/// Positional arguments of `func`; `DISTINCT` and argument clauses are
/// rejected with an error.
fn extract_args(func: &SQLFunction) -> PolarsResult<Vec<&FunctionArgExpr>> {
    _extract_func_args(func, false, false).map(|(args, _, _)| args)
}
/// Positional arguments of `func` plus whether `DISTINCT` was given;
/// argument clauses are rejected with an error.
fn extract_args_distinct(func: &SQLFunction) -> PolarsResult<(Vec<&FunctionArgExpr>, bool)> {
    _extract_func_args(func, true, false).map(|(args, is_distinct, _)| (args, is_distinct))
}
/// Positional arguments of `func`, whether `DISTINCT` was given, and the
/// argument clauses (e.g. `ORDER BY`, `LIMIT`).
fn extract_args_and_clauses(
    func: &SQLFunction,
) -> PolarsResult<(Vec<&FunctionArgExpr>, bool, Vec<FunctionArgumentClause>)> {
    let (get_distinct, get_clauses) = (true, true);
    _extract_func_args(func, get_distinct, get_clauses)
}
/// Unpack the argument list of `func`.
///
/// Returns the argument expressions (named or unnamed, names discarded),
/// whether `DISTINCT` was specified, and a copy of the argument clauses.
///
/// # Errors
/// * `DISTINCT` is present but neither `get_distinct` nor `get_clauses` is set
/// * clauses are present but `get_clauses` is not set
/// * the arguments are a subquery
fn _extract_func_args(
    func: &SQLFunction,
    get_distinct: bool,
    get_clauses: bool,
) -> PolarsResult<(Vec<&FunctionArgExpr>, bool, Vec<FunctionArgumentClause>)> {
    let FunctionArgumentList {
        args,
        duplicate_treatment,
        clauses,
    } = match &func.args {
        FunctionArguments::None => return Ok((vec![], false, vec![])),
        FunctionArguments::Subquery { .. } => {
            return Err(polars_err!(SQLInterface: "subquery not expected in {}", func.name));
        },
        FunctionArguments::List(list) => list,
    };

    let is_distinct = matches!(duplicate_treatment, Some(DuplicateTreatment::Distinct));
    if is_distinct && !get_clauses && !get_distinct {
        polars_bail!(SQLSyntax: "unexpected use of DISTINCT found in '{}'", func.name)
    }
    if !get_clauses && !clauses.is_empty() {
        polars_bail!(SQLSyntax: "unexpected clause found in '{}' ({})", func.name, clauses[0])
    }

    let unpacked_args = args
        .iter()
        .map(|arg| match arg {
            FunctionArg::Named { arg, .. }
            | FunctionArg::ExprNamed { arg, .. }
            | FunctionArg::Unnamed(arg) => arg,
        })
        .collect();
    Ok((unpacked_args, is_distinct, clauses.clone()))
}
/// Conversion from a parsed SQL expression into a Rust value.
///
/// Used by the binary/ternary visitors to turn the non-leading arguments of
/// a SQL function into the argument type their closure expects.
pub(crate) trait FromSQLExpr {
/// Convert `expr` into `Self`, returning an error when the SQL expression
/// cannot be represented as this type.
fn from_sql_expr(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Self>
where
Self: Sized;
}
impl FromSQLExpr for f64 {
    /// Accepts only a numeric SQL literal, parsed as `f64`.
    fn from_sql_expr(expr: &SQLExpr, _ctx: &mut SQLContext) -> PolarsResult<Self> {
        match expr {
            SQLExpr::Value(SQLValue::Number(s, _)) => s
                .parse()
                .map_err(|_| polars_err!(SQLInterface: "cannot parse literal {:?}", s)),
            SQLExpr::Value(other) => {
                polars_bail!(SQLInterface: "cannot parse literal {:?}", other)
            },
            _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
        }
    }
}
impl FromSQLExpr for bool {
    /// Accepts only a boolean SQL literal.
    fn from_sql_expr(expr: &SQLExpr, _ctx: &mut SQLContext) -> PolarsResult<Self> {
        match expr {
            SQLExpr::Value(SQLValue::Boolean(flag)) => Ok(*flag),
            SQLExpr::Value(other) => {
                polars_bail!(SQLInterface: "cannot parse boolean {:?}", other)
            },
            _ => polars_bail!(SQLInterface: "cannot parse boolean {:?}", expr),
        }
    }
}
impl FromSQLExpr for String {
    /// Accepts only a single-quoted SQL string literal.
    fn from_sql_expr(expr: &SQLExpr, _: &mut SQLContext) -> PolarsResult<Self> {
        match expr {
            SQLExpr::Value(SQLValue::SingleQuotedString(text)) => Ok(text.clone()),
            SQLExpr::Value(other) => {
                polars_bail!(SQLInterface: "cannot parse literal {:?}", other)
            },
            _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
        }
    }
}
impl FromSQLExpr for StrptimeOptions {
    /// Accepts only a single-quoted SQL string literal, which becomes the
    /// `format` of otherwise default options.
    fn from_sql_expr(expr: &SQLExpr, _: &mut SQLContext) -> PolarsResult<Self> {
        match expr {
            SQLExpr::Value(SQLValue::SingleQuotedString(fmt)) => {
                let format = Some(PlSmallStr::from_str(fmt));
                Ok(StrptimeOptions {
                    format,
                    ..StrptimeOptions::default()
                })
            },
            SQLExpr::Value(other) => {
                polars_bail!(SQLInterface: "cannot parse literal {:?}", other)
            },
            _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
        }
    }
}
impl FromSQLExpr for Expr {
    /// Parses any SQL expression into a Polars expression; no active schema
    /// is passed to the parser.
    fn from_sql_expr(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Self> {
        let no_schema = None;
        parse_sql_expr(expr, ctx, no_schema)
    }
}