Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-sql/src/sql_expr.rs
8336 views
1
//! Expressions that are supported by the Polars SQL interface.
2
//!
3
//! This is useful for syntax highlighting
4
//!
5
//! This module defines:
6
//! - all Polars SQL keywords [`all_keywords`]
7
//! - all Polars SQL functions [`all_functions`]
8
9
use std::fmt::Display;
10
use std::ops::Div;
11
12
use polars_core::prelude::*;
13
use polars_lazy::prelude::*;
14
use polars_plan::plans::DynLiteralValue;
15
use polars_plan::prelude::typed_lit;
16
use polars_time::Duration;
17
use polars_utils::unique_column_name;
18
#[cfg(feature = "serde")]
19
use serde::{Deserialize, Serialize};
20
use sqlparser::ast::{
21
AccessExpr, BinaryOperator as SQLBinaryOperator, CastFormat, CastKind, DataType as SQLDataType,
22
DateTimeField, Expr as SQLExpr, Function as SQLFunction, Ident, Interval, Query as Subquery,
23
SelectItem, Subscript, TimezoneInfo, TrimWhereField, TypedString, UnaryOperator,
24
Value as SQLValue, ValueWithSpan,
25
};
26
use sqlparser::dialect::GenericDialect;
27
use sqlparser::parser::{Parser, ParserOptions};
28
29
use crate::SQLContext;
30
use crate::functions::SQLFunctionVisitor;
31
use crate::types::{
32
bitstring_to_bytes_literal, is_iso_date, is_iso_datetime, is_iso_time, map_sql_dtype_to_polars,
33
};
34
35
#[inline]
36
#[cold]
37
#[must_use]
38
/// Convert a Display-able error to PolarsError::SQLInterface
39
pub fn to_sql_interface_err(err: impl Display) -> PolarsError {
40
PolarsError::SQLInterface(err.to_string().into())
41
}
42
43
/// Categorises the type of (allowed) subquery constraint.
///
/// Only a single restriction exists today; the commented-out names below
/// are candidates that have not been implemented.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum SubqueryRestriction {
    /// Subquery must return a single column
    SingleColumn,
    // SingleRow,
    // SingleValue,
    // Any
}
53
54
/// Recursively walks a SQL Expr to create a polars Expr
55
pub(crate) struct SQLExprVisitor<'a> {
56
ctx: &'a mut SQLContext,
57
active_schema: Option<&'a Schema>,
58
}
59
60
impl SQLExprVisitor<'_> {
61
fn array_expr_to_series(&mut self, elements: &[SQLExpr]) -> PolarsResult<Series> {
62
let mut array_elements = Vec::with_capacity(elements.len());
63
for e in elements {
64
let val = match e {
65
SQLExpr::Value(ValueWithSpan { value: v, .. }) => self.visit_any_value(v, None),
66
SQLExpr::UnaryOp { op, expr } => match expr.as_ref() {
67
SQLExpr::Value(ValueWithSpan { value: v, .. }) => {
68
self.visit_any_value(v, Some(op))
69
},
70
_ => Err(polars_err!(SQLInterface: "array element {:?} is not supported", e)),
71
},
72
SQLExpr::Array(values) => {
73
let srs = self.array_expr_to_series(&values.elem)?;
74
Ok(AnyValue::List(srs))
75
},
76
_ => Err(polars_err!(SQLInterface: "array element {:?} is not supported", e)),
77
}?
78
.into_static();
79
array_elements.push(val);
80
}
81
Series::from_any_values(PlSmallStr::EMPTY, &array_elements, true)
82
}
83
84
fn visit_expr(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
85
match expr {
86
SQLExpr::AllOp {
87
left,
88
compare_op,
89
right,
90
} => self.visit_all(left, compare_op, right),
91
SQLExpr::AnyOp {
92
left,
93
compare_op,
94
right,
95
is_some: _,
96
} => self.visit_any(left, compare_op, right),
97
SQLExpr::Array(arr) => self.visit_array_expr(&arr.elem, true, None),
98
SQLExpr::Between {
99
expr,
100
negated,
101
low,
102
high,
103
} => self.visit_between(expr, *negated, low, high),
104
SQLExpr::BinaryOp { left, op, right } => self.visit_binary_op(left, op, right),
105
SQLExpr::Cast {
106
kind,
107
expr,
108
data_type,
109
format,
110
} => self.visit_cast(expr, data_type, format, kind),
111
SQLExpr::Ceil { expr, .. } => Ok(self.visit_expr(expr)?.ceil()),
112
SQLExpr::CompoundFieldAccess { root, access_chain } => {
113
// simple subscript access (eg: "array_col[1]")
114
if access_chain.len() == 1 {
115
match &access_chain[0] {
116
AccessExpr::Subscript(subscript) => {
117
return self.visit_subscript(root, subscript);
118
},
119
AccessExpr::Dot(_) => {
120
polars_bail!(SQLSyntax: "dot-notation field access is currently unsupported: {:?}", access_chain[0])
121
},
122
}
123
}
124
// chained dot/bracket notation (eg: "struct_col.field[2].foo[0].bar")
125
polars_bail!(SQLSyntax: "complex field access chains are currently unsupported: {:?}", access_chain[0])
126
},
127
SQLExpr::CompoundIdentifier(idents) => self.visit_compound_identifier(idents),
128
SQLExpr::Extract {
129
field,
130
syntax: _,
131
expr,
132
} => parse_extract_date_part(self.visit_expr(expr)?, field),
133
SQLExpr::Floor { expr, .. } => Ok(self.visit_expr(expr)?.floor()),
134
SQLExpr::Function(function) => self.visit_function(function),
135
SQLExpr::Identifier(ident) => self.visit_identifier(ident),
136
SQLExpr::InList {
137
expr,
138
list,
139
negated,
140
} => {
141
let expr = self.visit_expr(expr)?;
142
let elems = self.visit_array_expr(list, true, Some(&expr))?;
143
let is_in = expr.is_in(elems, false);
144
Ok(if *negated { is_in.not() } else { is_in })
145
},
146
SQLExpr::InSubquery {
147
expr,
148
subquery,
149
negated,
150
} => self.visit_in_subquery(expr, subquery, *negated),
151
SQLExpr::Interval(interval) => Ok(lit(interval_to_duration(interval, true)?)),
152
SQLExpr::IsDistinctFrom(e1, e2) => {
153
Ok(self.visit_expr(e1)?.neq_missing(self.visit_expr(e2)?))
154
},
155
SQLExpr::IsFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false))),
156
SQLExpr::IsNotDistinctFrom(e1, e2) => {
157
Ok(self.visit_expr(e1)?.eq_missing(self.visit_expr(e2)?))
158
},
159
SQLExpr::IsNotFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false)).not()),
160
SQLExpr::IsNotNull(expr) => Ok(self.visit_expr(expr)?.is_not_null()),
161
SQLExpr::IsNotTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true)).not()),
162
SQLExpr::IsNull(expr) => Ok(self.visit_expr(expr)?.is_null()),
163
SQLExpr::IsTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true))),
164
SQLExpr::Like {
165
negated,
166
any,
167
expr,
168
pattern,
169
escape_char,
170
} => {
171
if *any {
172
polars_bail!(SQLSyntax: "LIKE ANY is not a supported syntax")
173
}
174
let escape_str = escape_char.as_ref().and_then(|v| match v {
175
SQLValue::SingleQuotedString(s) => Some(s.clone()),
176
_ => None,
177
});
178
self.visit_like(*negated, expr, pattern, &escape_str, false)
179
},
180
SQLExpr::ILike {
181
negated,
182
any,
183
expr,
184
pattern,
185
escape_char,
186
} => {
187
if *any {
188
polars_bail!(SQLSyntax: "ILIKE ANY is not a supported syntax")
189
}
190
let escape_str = escape_char.as_ref().and_then(|v| match v {
191
SQLValue::SingleQuotedString(s) => Some(s.clone()),
192
_ => None,
193
});
194
self.visit_like(*negated, expr, pattern, &escape_str, true)
195
},
196
SQLExpr::Nested(expr) => self.visit_expr(expr),
197
SQLExpr::Position { expr, r#in } => Ok(
198
// note: SQL is 1-indexed
199
(self
200
.visit_expr(r#in)?
201
.str()
202
.find(self.visit_expr(expr)?, true)
203
+ typed_lit(1u32))
204
.fill_null(typed_lit(0u32)),
205
),
206
SQLExpr::RLike {
207
// note: parses both RLIKE and REGEXP
208
negated,
209
expr,
210
pattern,
211
regexp: _,
212
} => {
213
let matches = self
214
.visit_expr(expr)?
215
.str()
216
.contains(self.visit_expr(pattern)?, true);
217
Ok(if *negated { matches.not() } else { matches })
218
},
219
SQLExpr::Subquery(_) => polars_bail!(SQLInterface: "unexpected subquery"),
220
SQLExpr::Substring {
221
expr,
222
substring_from,
223
substring_for,
224
..
225
} => self.visit_substring(expr, substring_from.as_deref(), substring_for.as_deref()),
226
SQLExpr::Trim {
227
expr,
228
trim_where,
229
trim_what,
230
trim_characters,
231
} => self.visit_trim(expr, trim_where, trim_what, trim_characters),
232
SQLExpr::TypedString(TypedString {
233
data_type,
234
value:
235
ValueWithSpan {
236
value: SQLValue::SingleQuotedString(v),
237
..
238
},
239
uses_odbc_syntax: _,
240
}) => match data_type {
241
SQLDataType::Date => {
242
if is_iso_date(v) {
243
Ok(lit(v.as_str()).cast(DataType::Date))
244
} else {
245
polars_bail!(SQLSyntax: "invalid DATE literal '{}'", v)
246
}
247
},
248
SQLDataType::Time(None, TimezoneInfo::None) => {
249
if is_iso_time(v) {
250
Ok(lit(v.as_str()).str().to_time(StrptimeOptions {
251
strict: true,
252
..Default::default()
253
}))
254
} else {
255
polars_bail!(SQLSyntax: "invalid TIME literal '{}'", v)
256
}
257
},
258
SQLDataType::Timestamp(None, TimezoneInfo::None) | SQLDataType::Datetime(None) => {
259
if is_iso_datetime(v) {
260
Ok(lit(v.as_str()).str().to_datetime(
261
None,
262
None,
263
StrptimeOptions {
264
strict: true,
265
..Default::default()
266
},
267
lit("latest"),
268
))
269
} else {
270
let fn_name = match data_type {
271
SQLDataType::Timestamp(_, _) => "TIMESTAMP",
272
SQLDataType::Datetime(_) => "DATETIME",
273
_ => unreachable!(),
274
};
275
polars_bail!(SQLSyntax: "invalid {} literal '{}'", fn_name, v)
276
}
277
},
278
_ => {
279
polars_bail!(SQLInterface: "typed literal should be one of DATE, DATETIME, TIME, or TIMESTAMP (found {})", data_type)
280
},
281
},
282
SQLExpr::UnaryOp { op, expr } => self.visit_unary_op(op, expr),
283
SQLExpr::Value(ValueWithSpan { value, .. }) => self.visit_literal(value),
284
SQLExpr::Wildcard(_) => Ok(all().as_expr()),
285
e @ SQLExpr::Case { .. } => self.visit_case_when_then(e),
286
other => {
287
polars_bail!(SQLInterface: "expression {:?} is not currently supported", other)
288
},
289
}
290
}
291
292
fn visit_subquery(
293
&mut self,
294
subquery: &Subquery,
295
restriction: SubqueryRestriction,
296
) -> PolarsResult<Expr> {
297
if subquery.with.is_some() {
298
polars_bail!(SQLSyntax: "SQL subquery cannot be a CTE 'WITH' clause");
299
}
300
// note: we have to execute subqueries in an isolated scope to prevent
301
// propagating any context/arena mutation into the rest of the query
302
let lf = self
303
.ctx
304
.execute_isolated(|ctx| ctx.execute_query_no_ctes(subquery))?;
305
306
if restriction == SubqueryRestriction::SingleColumn {
307
let new_name = unique_column_name();
308
return Ok(Expr::SubPlan(
309
SpecialEq::new(Arc::new(lf.logical_plan)),
310
// TODO: pass the implode depending on expr.
311
vec![(
312
new_name.clone(),
313
first().as_expr().implode().alias(new_name.clone()),
314
)],
315
));
316
};
317
polars_bail!(SQLInterface: "subquery type not supported");
318
}
319
320
/// Visit a single SQL identifier.
321
///
322
/// e.g. column
323
fn visit_identifier(&self, ident: &Ident) -> PolarsResult<Expr> {
324
Ok(col(ident.value.as_str()))
325
}
326
327
/// Visit a compound SQL identifier
328
///
329
/// e.g. tbl.column, struct.field, tbl.struct.field (inc. nested struct fields)
330
fn visit_compound_identifier(&mut self, idents: &[Ident]) -> PolarsResult<Expr> {
331
Ok(resolve_compound_identifier(self.ctx, idents, self.active_schema)?[0].clone())
332
}
333
334
fn visit_like(
335
&mut self,
336
negated: bool,
337
expr: &SQLExpr,
338
pattern: &SQLExpr,
339
escape_char: &Option<String>,
340
case_insensitive: bool,
341
) -> PolarsResult<Expr> {
342
if escape_char.is_some() {
343
polars_bail!(SQLInterface: "ESCAPE char for LIKE/ILIKE is not currently supported; found '{}'", escape_char.clone().unwrap());
344
}
345
let pat = match self.visit_expr(pattern) {
346
Ok(Expr::Literal(lv)) if lv.extract_str().is_some() => {
347
PlSmallStr::from_str(lv.extract_str().unwrap())
348
},
349
_ => {
350
polars_bail!(SQLSyntax: "LIKE/ILIKE pattern must be a string literal; found {}", pattern)
351
},
352
};
353
if pat.is_empty() || (!case_insensitive && pat.chars().all(|c| !matches!(c, '%' | '_'))) {
354
// empty string or other exact literal match (eg: no wildcard chars)
355
let op = if negated {
356
SQLBinaryOperator::NotEq
357
} else {
358
SQLBinaryOperator::Eq
359
};
360
self.visit_binary_op(expr, &op, pattern)
361
} else {
362
// create regex from pattern containing SQL wildcard chars ('%' => '.*', '_' => '.')
363
let mut rx = regex::escape(pat.as_str())
364
.replace('%', ".*")
365
.replace('_', ".");
366
367
rx = format!(
368
"^{}{}$",
369
if case_insensitive { "(?is)" } else { "(?s)" },
370
rx
371
);
372
373
let expr = self.visit_expr(expr)?;
374
let matches = expr.str().contains(lit(rx), true);
375
Ok(if negated { matches.not() } else { matches })
376
}
377
}
378
379
fn visit_subscript(&mut self, expr: &SQLExpr, subscript: &Subscript) -> PolarsResult<Expr> {
380
let expr = self.visit_expr(expr)?;
381
Ok(match subscript {
382
Subscript::Index { index } => {
383
let idx = adjust_one_indexed_param(self.visit_expr(index)?, true);
384
expr.list().get(idx, true)
385
},
386
Subscript::Slice { .. } => {
387
polars_bail!(SQLSyntax: "array slice syntax is not currently supported")
388
},
389
})
390
}
391
392
/// Handle implicit temporal string comparisons.
393
///
394
/// eg: clauses such as -
395
/// "dt >= '2024-04-30'"
396
/// "dt = '2077-10-10'::date"
397
/// "dtm::date = '2077-10-10'
398
fn convert_temporal_strings(&mut self, left: &Expr, right: &Expr) -> Expr {
399
if let (Some(name), Some(s), expr_dtype) = match (left, right) {
400
// identify "col <op> string" expressions
401
(Expr::Column(name), Expr::Literal(lv)) if lv.extract_str().is_some() => {
402
(Some(name.clone()), Some(lv.extract_str().unwrap()), None)
403
},
404
// identify "CAST(expr AS type) <op> string" and/or "expr::type <op> string" expressions
405
(Expr::Cast { expr, dtype, .. }, Expr::Literal(lv)) if lv.extract_str().is_some() => {
406
let s = lv.extract_str().unwrap();
407
match &**expr {
408
Expr::Column(name) => (Some(name.clone()), Some(s), Some(dtype)),
409
_ => (None, Some(s), Some(dtype)),
410
}
411
},
412
_ => (None, None, None),
413
} {
414
if expr_dtype.is_none() && self.active_schema.is_none() {
415
right.clone()
416
} else {
417
let left_dtype = expr_dtype.map_or_else(
418
|| {
419
self.active_schema
420
.as_ref()
421
.and_then(|schema| schema.get(&name))
422
},
423
|dt| dt.as_literal(),
424
);
425
match left_dtype {
426
Some(DataType::Time) if is_iso_time(s) => {
427
right.clone().str().to_time(StrptimeOptions {
428
strict: true,
429
..Default::default()
430
})
431
},
432
Some(DataType::Date) if is_iso_date(s) => {
433
right.clone().str().to_date(StrptimeOptions {
434
strict: true,
435
..Default::default()
436
})
437
},
438
Some(DataType::Datetime(tu, tz)) if is_iso_datetime(s) || is_iso_date(s) => {
439
if s.len() == 10 {
440
// handle upcast from ISO date string (10 chars) to datetime
441
lit(format!("{s}T00:00:00"))
442
} else {
443
lit(s.replacen(' ', "T", 1))
444
}
445
.str()
446
.to_datetime(
447
Some(*tu),
448
tz.clone(),
449
StrptimeOptions {
450
strict: true,
451
..Default::default()
452
},
453
lit("latest"),
454
)
455
},
456
_ => right.clone(),
457
}
458
}
459
} else {
460
right.clone()
461
}
462
}
463
464
fn struct_field_access_expr(
465
&mut self,
466
expr: &Expr,
467
path: &str,
468
infer_index: bool,
469
) -> PolarsResult<Expr> {
470
let path_elems = if path.starts_with('{') && path.ends_with('}') {
471
path.trim_matches(|c| c == '{' || c == '}')
472
} else {
473
path
474
}
475
.split(',');
476
477
let mut expr = expr.clone();
478
for p in path_elems {
479
let p = p.trim();
480
expr = if infer_index {
481
match p.parse::<i64>() {
482
Ok(idx) => expr.list().get(lit(idx), true),
483
Err(_) => expr.struct_().field_by_name(p),
484
}
485
} else {
486
expr.struct_().field_by_name(p)
487
}
488
}
489
Ok(expr)
490
}
491
492
/// Visit a SQL binary operator.
493
///
494
/// e.g. "column + 1", "column1 <= column2"
495
fn visit_binary_op(
496
&mut self,
497
left: &SQLExpr,
498
op: &SQLBinaryOperator,
499
right: &SQLExpr,
500
) -> PolarsResult<Expr> {
501
// check for (unsupported) scalar subquery comparisons
502
if matches!(left, SQLExpr::Subquery(_)) || matches!(right, SQLExpr::Subquery(_)) {
503
let (suggestion, str_op) = match op {
504
SQLBinaryOperator::NotEq => ("; use 'NOT IN' instead", "!=".to_string()),
505
SQLBinaryOperator::Eq => ("; use 'IN' instead", format!("{op}")),
506
_ => ("", format!("{op}")),
507
};
508
polars_bail!(
509
SQLSyntax: "subquery comparisons with '{str_op}' are not supported{suggestion}"
510
);
511
}
512
513
// need special handling for interval offsets and comparisons
514
let (lhs, mut rhs) = match (left, op, right) {
515
(_, SQLBinaryOperator::Minus, SQLExpr::Interval(v)) => {
516
let duration = interval_to_duration(v, false)?;
517
return Ok(self
518
.visit_expr(left)?
519
.dt()
520
.offset_by(lit(format!("-{duration}"))));
521
},
522
(_, SQLBinaryOperator::Plus, SQLExpr::Interval(v)) => {
523
let duration = interval_to_duration(v, false)?;
524
return Ok(self
525
.visit_expr(left)?
526
.dt()
527
.offset_by(lit(format!("{duration}"))));
528
},
529
(SQLExpr::Interval(v1), _, SQLExpr::Interval(v2)) => {
530
// shortcut interval comparison evaluation (-> bool)
531
let d1 = interval_to_duration(v1, false)?;
532
let d2 = interval_to_duration(v2, false)?;
533
let res = match op {
534
SQLBinaryOperator::Gt => Ok(lit(d1 > d2)),
535
SQLBinaryOperator::Lt => Ok(lit(d1 < d2)),
536
SQLBinaryOperator::GtEq => Ok(lit(d1 >= d2)),
537
SQLBinaryOperator::LtEq => Ok(lit(d1 <= d2)),
538
SQLBinaryOperator::NotEq => Ok(lit(d1 != d2)),
539
SQLBinaryOperator::Eq | SQLBinaryOperator::Spaceship => Ok(lit(d1 == d2)),
540
_ => polars_bail!(SQLInterface: "invalid interval comparison operator"),
541
};
542
if res.is_ok() {
543
return res;
544
}
545
(self.visit_expr(left)?, self.visit_expr(right)?)
546
},
547
_ => (self.visit_expr(left)?, self.visit_expr(right)?),
548
};
549
rhs = self.convert_temporal_strings(&lhs, &rhs);
550
551
Ok(match op {
552
// ----
553
// Bitwise operators
554
// ----
555
SQLBinaryOperator::BitwiseAnd => lhs.and(rhs), // "x & y"
556
SQLBinaryOperator::BitwiseOr => lhs.or(rhs), // "x | y"
557
SQLBinaryOperator::Xor => lhs.xor(rhs), // "x XOR y"
558
559
// ----
560
// General operators
561
// ----
562
SQLBinaryOperator::And => lhs.and(rhs), // "x AND y"
563
SQLBinaryOperator::Divide => lhs / rhs, // "x / y"
564
SQLBinaryOperator::DuckIntegerDivide => lhs.floor_div(rhs).cast(DataType::Int64), // "x // y"
565
SQLBinaryOperator::Eq => lhs.eq(rhs), // "x = y"
566
SQLBinaryOperator::Gt => lhs.gt(rhs), // "x > y"
567
SQLBinaryOperator::GtEq => lhs.gt_eq(rhs), // "x >= y"
568
SQLBinaryOperator::Lt => lhs.lt(rhs), // "x < y"
569
SQLBinaryOperator::LtEq => lhs.lt_eq(rhs), // "x <= y"
570
SQLBinaryOperator::Minus => lhs - rhs, // "x - y"
571
SQLBinaryOperator::Modulo => lhs % rhs, // "x % y"
572
SQLBinaryOperator::Multiply => lhs * rhs, // "x * y"
573
SQLBinaryOperator::NotEq => lhs.eq(rhs).not(), // "x != y"
574
SQLBinaryOperator::Or => lhs.or(rhs), // "x OR y"
575
SQLBinaryOperator::Plus => lhs + rhs, // "x + y"
576
SQLBinaryOperator::Spaceship => lhs.eq_missing(rhs), // "x <=> y"
577
SQLBinaryOperator::StringConcat => { // "x || y"
578
lhs.cast(DataType::String) + rhs.cast(DataType::String)
579
},
580
SQLBinaryOperator::PGStartsWith => lhs.str().starts_with(rhs), // "x ^@ y"
581
// ----
582
// Regular expression operators
583
// ----
584
SQLBinaryOperator::PGRegexMatch => match rhs { // "x ~ y"
585
Expr::Literal(ref lv) if lv.extract_str().is_some() => lhs.str().contains(rhs, true),
586
_ => polars_bail!(SQLSyntax: "invalid pattern for '~' operator: {:?}", rhs),
587
},
588
SQLBinaryOperator::PGRegexNotMatch => match rhs { // "x !~ y"
589
Expr::Literal(ref lv) if lv.extract_str().is_some() => lhs.str().contains(rhs, true).not(),
590
_ => polars_bail!(SQLSyntax: "invalid pattern for '!~' operator: {:?}", rhs),
591
},
592
SQLBinaryOperator::PGRegexIMatch => match rhs { // "x ~* y"
593
Expr::Literal(ref lv) if lv.extract_str().is_some() => {
594
let pat = lv.extract_str().unwrap();
595
lhs.str().contains(lit(format!("(?i){pat}")), true)
596
},
597
_ => polars_bail!(SQLSyntax: "invalid pattern for '~*' operator: {:?}", rhs),
598
},
599
SQLBinaryOperator::PGRegexNotIMatch => match rhs { // "x !~* y"
600
Expr::Literal(ref lv) if lv.extract_str().is_some() => {
601
let pat = lv.extract_str().unwrap();
602
lhs.str().contains(lit(format!("(?i){pat}")), true).not()
603
},
604
_ => {
605
polars_bail!(SQLSyntax: "invalid pattern for '!~*' operator: {:?}", rhs)
606
},
607
},
608
// ----
609
// LIKE/ILIKE operators
610
// ----
611
SQLBinaryOperator::PGLikeMatch // "x ~~ y"
612
| SQLBinaryOperator::PGNotLikeMatch // "x !~~ y"
613
| SQLBinaryOperator::PGILikeMatch // "x ~~* y"
614
| SQLBinaryOperator::PGNotILikeMatch => { // "x !~~* y"
615
let expr = if matches!(
616
op,
617
SQLBinaryOperator::PGLikeMatch | SQLBinaryOperator::PGNotLikeMatch
618
) {
619
SQLExpr::Like {
620
negated: matches!(op, SQLBinaryOperator::PGNotLikeMatch),
621
any: false,
622
expr: Box::new(left.clone()),
623
pattern: Box::new(right.clone()),
624
escape_char: None,
625
}
626
} else {
627
SQLExpr::ILike {
628
negated: matches!(op, SQLBinaryOperator::PGNotILikeMatch),
629
any: false,
630
expr: Box::new(left.clone()),
631
pattern: Box::new(right.clone()),
632
escape_char: None,
633
}
634
};
635
self.visit_expr(&expr)?
636
},
637
// ----
638
// JSON/Struct field access operators
639
// ----
640
SQLBinaryOperator::Arrow | SQLBinaryOperator::LongArrow => match rhs { // "x -> y", "x ->> y"
641
Expr::Literal(lv) if lv.extract_str().is_some() => {
642
let path = lv.extract_str().unwrap();
643
let mut expr = self.struct_field_access_expr(&lhs, path, false)?;
644
if let SQLBinaryOperator::LongArrow = op {
645
expr = expr.cast(DataType::String);
646
}
647
expr
648
},
649
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(idx))) => {
650
let mut expr = self.struct_field_access_expr(&lhs, &idx.to_string(), true)?;
651
if let SQLBinaryOperator::LongArrow = op {
652
expr = expr.cast(DataType::String);
653
}
654
expr
655
},
656
_ => {
657
polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", right)
658
},
659
},
660
SQLBinaryOperator::HashArrow | SQLBinaryOperator::HashLongArrow => { // "x #> y", "x #>> y"
661
match rhs {
662
Expr::Literal(lv) if lv.extract_str().is_some() => {
663
let path = lv.extract_str().unwrap();
664
let mut expr = self.struct_field_access_expr(&lhs, path, true)?;
665
if let SQLBinaryOperator::HashLongArrow = op {
666
expr = expr.cast(DataType::String);
667
}
668
expr
669
},
670
_ => {
671
polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", rhs)
672
}
673
}
674
},
675
other => {
676
polars_bail!(SQLInterface: "operator {:?} is not currently supported", other)
677
},
678
})
679
}
680
681
/// Visit a SQL unary operator.
682
///
683
/// e.g. +column or -column
684
fn visit_unary_op(&mut self, op: &UnaryOperator, expr: &SQLExpr) -> PolarsResult<Expr> {
685
let expr = self.visit_expr(expr)?;
686
Ok(match (op, expr.clone()) {
687
// simplify the parse tree by special-casing common unary +/- ops
688
(UnaryOperator::Plus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) => {
689
lit(n)
690
},
691
(UnaryOperator::Plus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(n)))) => {
692
lit(n)
693
},
694
(UnaryOperator::Minus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) => {
695
lit(-n)
696
},
697
(UnaryOperator::Minus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(n)))) => {
698
lit(-n)
699
},
700
// general case
701
(UnaryOperator::Plus, _) => lit(0) + expr,
702
(UnaryOperator::Minus, _) => lit(0) - expr,
703
(UnaryOperator::Not, _) => match &expr {
704
Expr::Column(name)
705
if self
706
.active_schema
707
.and_then(|schema| schema.get(name))
708
.is_some_and(|dtype| matches!(dtype, DataType::Boolean)) =>
709
{
710
// if already boolean, can operate bitwise
711
expr.not()
712
},
713
// otherwise SQL "NOT" expects logical, not bitwise, behaviour (eg: on integers)
714
_ => expr.strict_cast(DataType::Boolean).not(),
715
},
716
other => polars_bail!(SQLInterface: "unary operator {:?} is not supported", other),
717
})
718
}
719
720
/// Visit a SQL function.
721
///
722
/// e.g. SUM(column) or COUNT(*)
723
///
724
/// See [SQLFunctionVisitor] for more details
725
fn visit_function(&mut self, function: &SQLFunction) -> PolarsResult<Expr> {
726
let mut visitor = SQLFunctionVisitor {
727
func: function,
728
ctx: self.ctx,
729
active_schema: self.active_schema,
730
};
731
visitor.visit_function()
732
}
733
734
/// Visit a SQL `ALL` expression.
735
///
736
/// e.g. `a > ALL(y)`
737
fn visit_all(
738
&mut self,
739
left: &SQLExpr,
740
compare_op: &SQLBinaryOperator,
741
right: &SQLExpr,
742
) -> PolarsResult<Expr> {
743
let left = self.visit_expr(left)?;
744
let right = self.visit_expr(right)?;
745
746
match compare_op {
747
SQLBinaryOperator::Gt => Ok(left.gt(right.max())),
748
SQLBinaryOperator::Lt => Ok(left.lt(right.min())),
749
SQLBinaryOperator::GtEq => Ok(left.gt_eq(right.max())),
750
SQLBinaryOperator::LtEq => Ok(left.lt_eq(right.min())),
751
SQLBinaryOperator::Eq => polars_bail!(SQLSyntax: "ALL cannot be used with ="),
752
SQLBinaryOperator::NotEq => polars_bail!(SQLSyntax: "ALL cannot be used with !="),
753
_ => polars_bail!(SQLInterface: "invalid comparison operator"),
754
}
755
}
756
757
/// Visit a SQL `ANY` expression.
758
///
759
/// e.g. `a != ANY(y)`
760
fn visit_any(
761
&mut self,
762
left: &SQLExpr,
763
compare_op: &SQLBinaryOperator,
764
right: &SQLExpr,
765
) -> PolarsResult<Expr> {
766
let left = self.visit_expr(left)?;
767
let right = self.visit_expr(right)?;
768
769
match compare_op {
770
SQLBinaryOperator::Gt => Ok(left.gt(right.min())),
771
SQLBinaryOperator::Lt => Ok(left.lt(right.max())),
772
SQLBinaryOperator::GtEq => Ok(left.gt_eq(right.min())),
773
SQLBinaryOperator::LtEq => Ok(left.lt_eq(right.max())),
774
SQLBinaryOperator::Eq => Ok(left.is_in(right, false)),
775
SQLBinaryOperator::NotEq => Ok(left.is_in(right, false).not()),
776
_ => polars_bail!(SQLInterface: "invalid comparison operator"),
777
}
778
}
779
780
/// Visit a SQL `ARRAY` list (including `IN` values).
781
fn visit_array_expr(
782
&mut self,
783
elements: &[SQLExpr],
784
result_as_element: bool,
785
dtype_expr_match: Option<&Expr>,
786
) -> PolarsResult<Expr> {
787
let mut elems = self.array_expr_to_series(elements)?;
788
789
// handle implicit temporal strings, eg: "dt IN ('2024-04-30','2024-05-01')".
790
// (not yet as versatile as the temporal string conversions in visit_binary_op)
791
if let (Some(Expr::Column(name)), Some(schema)) =
792
(dtype_expr_match, self.active_schema.as_ref())
793
{
794
if elems.dtype() == &DataType::String {
795
if let Some(dtype) = schema.get(name) {
796
if matches!(
797
dtype,
798
DataType::Date | DataType::Time | DataType::Datetime(_, _)
799
) {
800
elems = elems.strict_cast(dtype)?;
801
}
802
}
803
}
804
}
805
806
// if we are parsing the list as an element in a series, implode.
807
// otherwise, return the series as-is.
808
let res = if result_as_element {
809
elems.implode()?.into_series()
810
} else {
811
elems
812
};
813
Ok(lit(res))
814
}
815
816
/// Visit a SQL `CAST` or `TRY_CAST` expression.
817
///
818
/// e.g. `CAST(col AS INT)`, `col::int4`, or `TRY_CAST(col AS VARCHAR)`,
819
fn visit_cast(
820
&mut self,
821
expr: &SQLExpr,
822
dtype: &SQLDataType,
823
format: &Option<CastFormat>,
824
cast_kind: &CastKind,
825
) -> PolarsResult<Expr> {
826
if format.is_some() {
827
return Err(
828
polars_err!(SQLInterface: "use of FORMAT is not currently supported in CAST"),
829
);
830
}
831
let expr = self.visit_expr(expr)?;
832
833
#[cfg(feature = "json")]
834
if dtype == &SQLDataType::JSON {
835
// @BROKEN: we cannot handle this.
836
return Ok(expr.str().json_decode(DataType::Struct(Vec::new())));
837
}
838
let polars_type = map_sql_dtype_to_polars(dtype)?;
839
Ok(match cast_kind {
840
CastKind::Cast | CastKind::DoubleColon => expr.strict_cast(polars_type),
841
CastKind::TryCast | CastKind::SafeCast => expr.cast(polars_type),
842
})
843
}
844
845
/// Visit a SQL literal.
846
///
847
/// e.g. 1, 'foo', 1.0, NULL
848
///
849
/// See [SQLValue] and [LiteralValue] for more details
850
fn visit_literal(&self, value: &SQLValue) -> PolarsResult<Expr> {
851
// note: double-quoted strings will be parsed as identifiers, not literals
852
Ok(match value {
853
SQLValue::Boolean(b) => lit(*b),
854
SQLValue::DollarQuotedString(s) => lit(s.value.clone()),
855
#[cfg(feature = "binary_encoding")]
856
SQLValue::HexStringLiteral(x) => {
857
if x.len() % 2 != 0 {
858
polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x)
859
};
860
lit(hex::decode(x.clone()).unwrap())
861
},
862
SQLValue::Null => Expr::Literal(LiteralValue::untyped_null()),
863
SQLValue::Number(s, _) => {
864
// Check for existence of decimal separator dot
865
if s.contains('.') {
866
s.parse::<f64>().map(lit).map_err(|_| ())
867
} else {
868
s.parse::<i64>().map(lit).map_err(|_| ())
869
}
870
.map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))?
871
},
872
SQLValue::SingleQuotedByteStringLiteral(b) => {
873
// note: for PostgreSQL this represents a BIT string literal (eg: b'10101') not a BYTE string
874
// literal (see https://www.postgresql.org/docs/current/datatype-bit.html), but sqlparser-rs
875
// patterned the token name after BigQuery (where b'str' really IS a byte string)
876
bitstring_to_bytes_literal(b)?
877
},
878
SQLValue::SingleQuotedString(s) => lit(s.clone()),
879
other => {
880
polars_bail!(SQLInterface: "value {:?} is not a supported literal type", other)
881
},
882
})
883
}
884
885
/// Visit a SQL literal (like [visit_literal]), but return AnyValue instead of Expr.
886
fn visit_any_value(
887
&self,
888
value: &SQLValue,
889
op: Option<&UnaryOperator>,
890
) -> PolarsResult<AnyValue<'_>> {
891
Ok(match value {
892
SQLValue::Boolean(b) => AnyValue::Boolean(*b),
893
SQLValue::DollarQuotedString(s) => AnyValue::StringOwned(s.clone().value.into()),
894
#[cfg(feature = "binary_encoding")]
895
SQLValue::HexStringLiteral(x) => {
896
if x.len() % 2 != 0 {
897
polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x)
898
};
899
AnyValue::BinaryOwned(hex::decode(x.clone()).unwrap())
900
},
901
SQLValue::Null => AnyValue::Null,
902
SQLValue::Number(s, _) => {
903
let negate = match op {
904
Some(UnaryOperator::Minus) => true,
905
// no op should be taken as plus.
906
Some(UnaryOperator::Plus) | None => false,
907
Some(op) => {
908
polars_bail!(SQLInterface: "unary op {:?} not supported for numeric SQL value", op)
909
},
910
};
911
// Check for existence of decimal separator dot
912
if s.contains('.') {
913
s.parse::<f64>()
914
.map(|n: f64| AnyValue::Float64(if negate { -n } else { n }))
915
.map_err(|_| ())
916
} else {
917
s.parse::<i64>()
918
.map(|n: i64| AnyValue::Int64(if negate { -n } else { n }))
919
.map_err(|_| ())
920
}
921
.map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))?
922
},
923
SQLValue::SingleQuotedByteStringLiteral(b) => {
924
// note: for PostgreSQL this represents a BIT literal (eg: b'10101') not BYTE
925
let bytes_literal = bitstring_to_bytes_literal(b)?;
926
match bytes_literal {
927
Expr::Literal(lv) if lv.extract_binary().is_some() => {
928
AnyValue::BinaryOwned(lv.extract_binary().unwrap().to_vec())
929
},
930
_ => {
931
polars_bail!(SQLInterface: "failed to parse bitstring literal: {:?}", b)
932
},
933
}
934
},
935
SQLValue::SingleQuotedString(s) => AnyValue::StringOwned(s.as_str().into()),
936
other => polars_bail!(SQLInterface: "value {:?} is not currently supported", other),
937
})
938
}
939
940
/// Visit a SQL `BETWEEN` expression.
941
/// See [sqlparser::ast::Expr::Between] for more details
942
fn visit_between(
943
&mut self,
944
expr: &SQLExpr,
945
negated: bool,
946
low: &SQLExpr,
947
high: &SQLExpr,
948
) -> PolarsResult<Expr> {
949
let expr = self.visit_expr(expr)?;
950
let low = self.visit_expr(low)?;
951
let high = self.visit_expr(high)?;
952
953
let low = self.convert_temporal_strings(&expr, &low);
954
let high = self.convert_temporal_strings(&expr, &high);
955
Ok(if negated {
956
expr.clone().lt(low).or(expr.gt(high))
957
} else {
958
expr.clone().gt_eq(low).and(expr.lt_eq(high))
959
})
960
}
961
962
/// Visit a SQL `TRIM` function.
963
/// See [sqlparser::ast::Expr::Trim] for more details
964
fn visit_trim(
965
&mut self,
966
expr: &SQLExpr,
967
trim_where: &Option<TrimWhereField>,
968
trim_what: &Option<Box<SQLExpr>>,
969
trim_characters: &Option<Vec<SQLExpr>>,
970
) -> PolarsResult<Expr> {
971
if trim_characters.is_some() {
972
// TODO: allow compact snowflake/bigquery syntax?
973
return Err(polars_err!(SQLSyntax: "unsupported TRIM syntax (custom chars)"));
974
};
975
let expr = self.visit_expr(expr)?;
976
let trim_what = trim_what.as_ref().map(|e| self.visit_expr(e)).transpose()?;
977
let trim_what = match trim_what {
978
Some(Expr::Literal(lv)) if lv.extract_str().is_some() => {
979
Some(PlSmallStr::from_str(lv.extract_str().unwrap()))
980
},
981
None => None,
982
_ => return self.err(&expr),
983
};
984
Ok(match (trim_where, trim_what) {
985
(None | Some(TrimWhereField::Both), None) => {
986
expr.str().strip_chars(lit(LiteralValue::untyped_null()))
987
},
988
(None | Some(TrimWhereField::Both), Some(val)) => expr.str().strip_chars(lit(val)),
989
(Some(TrimWhereField::Leading), None) => expr
990
.str()
991
.strip_chars_start(lit(LiteralValue::untyped_null())),
992
(Some(TrimWhereField::Leading), Some(val)) => expr.str().strip_chars_start(lit(val)),
993
(Some(TrimWhereField::Trailing), None) => expr
994
.str()
995
.strip_chars_end(lit(LiteralValue::untyped_null())),
996
(Some(TrimWhereField::Trailing), Some(val)) => expr.str().strip_chars_end(lit(val)),
997
})
998
}
999
1000
fn visit_substring(
1001
&mut self,
1002
expr: &SQLExpr,
1003
substring_from: Option<&SQLExpr>,
1004
substring_for: Option<&SQLExpr>,
1005
) -> PolarsResult<Expr> {
1006
let e = self.visit_expr(expr)?;
1007
1008
match (substring_from, substring_for) {
1009
// SUBSTRING(expr FROM start FOR length)
1010
(Some(from_expr), Some(for_expr)) => {
1011
let start = self.visit_expr(from_expr)?;
1012
let length = self.visit_expr(for_expr)?;
1013
1014
// note: SQL is 1-indexed, so we need to adjust the offsets accordingly
1015
Ok(match (start.clone(), length.clone()) {
1016
(Expr::Literal(lv), _) | (_, Expr::Literal(lv)) if lv.is_null() => lit(lv),
1017
(_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) if n < 0 => {
1018
polars_bail!(SQLSyntax: "SUBSTR does not support negative length ({})", n)
1019
},
1020
(Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) if n > 0 => {
1021
e.str().slice(lit(n - 1), length)
1022
},
1023
(Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) => e
1024
.str()
1025
.slice(lit(0), (length + lit(n - 1)).clip_min(lit(0))),
1026
(Expr::Literal(_), _) => {
1027
polars_bail!(SQLSyntax: "invalid 'start' for SUBSTRING")
1028
},
1029
(_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(_)))) => {
1030
polars_bail!(SQLSyntax: "invalid 'length' for SUBSTRING")
1031
},
1032
_ => {
1033
let adjusted_start = start - lit(1);
1034
when(adjusted_start.clone().lt(lit(0)))
1035
.then(e.clone().str().slice(
1036
lit(0),
1037
(length.clone() + adjusted_start.clone()).clip_min(lit(0)),
1038
))
1039
.otherwise(e.str().slice(adjusted_start, length))
1040
},
1041
})
1042
},
1043
// SUBSTRING(expr FROM start)
1044
(Some(from_expr), None) => {
1045
let start = self.visit_expr(from_expr)?;
1046
1047
Ok(match start {
1048
Expr::Literal(lv) if lv.is_null() => lit(lv),
1049
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n <= 0 => e,
1050
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1051
e.str().slice(lit(n - 1), lit(LiteralValue::untyped_null()))
1052
},
1053
Expr::Literal(_) => {
1054
polars_bail!(SQLSyntax: "invalid 'start' for SUBSTRING")
1055
},
1056
_ => e
1057
.str()
1058
.slice(start - lit(1), lit(LiteralValue::untyped_null())),
1059
})
1060
},
1061
// SUBSTRING(expr) - not valid, but handle gracefully
1062
(None, _) => {
1063
polars_bail!(SQLSyntax: "SUBSTR expects 2-3 arguments (found 1)")
1064
},
1065
}
1066
}
1067
1068
/// Visit a SQL subquery inside an `IN` expression.
1069
fn visit_in_subquery(
1070
&mut self,
1071
expr: &SQLExpr,
1072
subquery: &Subquery,
1073
negated: bool,
1074
) -> PolarsResult<Expr> {
1075
let subquery_result = self.visit_subquery(subquery, SubqueryRestriction::SingleColumn)?;
1076
let expr = self.visit_expr(expr)?;
1077
Ok(if negated {
1078
expr.is_in(subquery_result, false).not()
1079
} else {
1080
expr.is_in(subquery_result, false)
1081
})
1082
}
1083
1084
/// Visit `CASE` control flow expression.
1085
fn visit_case_when_then(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
1086
if let SQLExpr::Case {
1087
case_token: _,
1088
end_token: _,
1089
operand,
1090
conditions,
1091
else_result,
1092
} = expr
1093
{
1094
polars_ensure!(
1095
!conditions.is_empty(),
1096
SQLSyntax: "WHEN and THEN expressions must have at least one element"
1097
);
1098
1099
let mut when_thens = conditions.iter();
1100
let first = when_thens.next();
1101
if first.is_none() {
1102
polars_bail!(SQLSyntax: "WHEN and THEN expressions must have at least one element");
1103
}
1104
let else_res = match else_result {
1105
Some(else_res) => self.visit_expr(else_res)?,
1106
None => lit(LiteralValue::untyped_null()), // ELSE clause is optional; when omitted, it is implicitly NULL
1107
};
1108
if let Some(operand_expr) = operand {
1109
let first_operand_expr = self.visit_expr(operand_expr)?;
1110
1111
let first = first.unwrap();
1112
let first_cond = first_operand_expr.eq(self.visit_expr(&first.condition)?);
1113
let first_then = self.visit_expr(&first.result)?;
1114
let expr = when(first_cond).then(first_then);
1115
let next = when_thens.next();
1116
1117
let mut when_then = if let Some(case_when) = next {
1118
let second_operand_expr = self.visit_expr(operand_expr)?;
1119
let cond = second_operand_expr.eq(self.visit_expr(&case_when.condition)?);
1120
let res = self.visit_expr(&case_when.result)?;
1121
expr.when(cond).then(res)
1122
} else {
1123
return Ok(expr.otherwise(else_res));
1124
};
1125
for case_when in when_thens {
1126
let new_operand_expr = self.visit_expr(operand_expr)?;
1127
let cond = new_operand_expr.eq(self.visit_expr(&case_when.condition)?);
1128
let res = self.visit_expr(&case_when.result)?;
1129
when_then = when_then.when(cond).then(res);
1130
}
1131
return Ok(when_then.otherwise(else_res));
1132
}
1133
1134
let first = first.unwrap();
1135
let first_cond = self.visit_expr(&first.condition)?;
1136
let first_then = self.visit_expr(&first.result)?;
1137
let expr = when(first_cond).then(first_then);
1138
let next = when_thens.next();
1139
1140
let mut when_then = if let Some(case_when) = next {
1141
let cond = self.visit_expr(&case_when.condition)?;
1142
let res = self.visit_expr(&case_when.result)?;
1143
expr.when(cond).then(res)
1144
} else {
1145
return Ok(expr.otherwise(else_res));
1146
};
1147
for case_when in when_thens {
1148
let cond = self.visit_expr(&case_when.condition)?;
1149
let res = self.visit_expr(&case_when.result)?;
1150
when_then = when_then.when(cond).then(res);
1151
}
1152
Ok(when_then.otherwise(else_res))
1153
} else {
1154
unreachable!()
1155
}
1156
}
1157
1158
fn err(&self, expr: &Expr) -> PolarsResult<Expr> {
1159
polars_bail!(SQLInterface: "expression {:?} is not currently supported", expr);
1160
}
1161
}
1162
1163
/// parse a SQL expression to a polars expression
1164
/// # Example
1165
/// ```rust
1166
/// # use polars_sql::{SQLContext, sql_expr};
1167
/// # use polars_core::prelude::*;
1168
/// # use polars_lazy::prelude::*;
1169
/// # fn main() {
1170
///
1171
/// let mut ctx = SQLContext::new();
1172
/// let df = df! {
1173
/// "a" => [1, 2, 3],
1174
/// }
1175
/// .unwrap();
1176
/// let expr = sql_expr("MAX(a)").unwrap();
1177
/// df.lazy().select(vec![expr]).collect().unwrap();
1178
/// # }
1179
/// ```
1180
pub fn sql_expr<S: AsRef<str>>(s: S) -> PolarsResult<Expr> {
1181
let mut ctx = SQLContext::new();
1182
1183
let mut parser = Parser::new(&GenericDialect);
1184
parser = parser.with_options(ParserOptions {
1185
trailing_commas: true,
1186
..Default::default()
1187
});
1188
1189
let mut ast = parser
1190
.try_with_sql(s.as_ref())
1191
.map_err(to_sql_interface_err)?;
1192
let expr = ast.parse_select_item().map_err(to_sql_interface_err)?;
1193
1194
Ok(match &expr {
1195
SelectItem::ExprWithAlias { expr, alias } => {
1196
let expr = parse_sql_expr(expr, &mut ctx, None)?;
1197
expr.alias(alias.value.as_str())
1198
},
1199
SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, &mut ctx, None)?,
1200
_ => polars_bail!(SQLInterface: "unable to parse '{}' as Expr", s.as_ref()),
1201
})
1202
}
1203
1204
pub(crate) fn interval_to_duration(interval: &Interval, fixed: bool) -> PolarsResult<Duration> {
1205
if interval.last_field.is_some()
1206
|| interval.leading_field.is_some()
1207
|| interval.leading_precision.is_some()
1208
|| interval.fractional_seconds_precision.is_some()
1209
{
1210
polars_bail!(SQLSyntax: "unsupported interval syntax ('{}')", interval)
1211
}
1212
let s = match &*interval.value {
1213
SQLExpr::UnaryOp { .. } => {
1214
polars_bail!(SQLSyntax: "unary ops are not valid on interval strings; found {}", interval.value)
1215
},
1216
SQLExpr::Value(ValueWithSpan {
1217
value: SQLValue::SingleQuotedString(s),
1218
..
1219
}) => Some(s),
1220
_ => None,
1221
};
1222
match s {
1223
Some(s) if s.contains('-') => {
1224
polars_bail!(SQLInterface: "minus signs are not yet supported in interval strings; found '{}'", s)
1225
},
1226
Some(s) => {
1227
// years, quarters, and months do not have a fixed duration; these
1228
// interval parts can only be used with respect to a reference point
1229
let duration = Duration::parse_interval(s);
1230
if fixed && duration.months() != 0 {
1231
polars_bail!(SQLSyntax: "fixed-duration interval cannot contain years, quarters, or months; found {}", s)
1232
};
1233
Ok(duration)
1234
},
1235
None => polars_bail!(SQLSyntax: "invalid interval {:?}", interval),
1236
}
1237
}
1238
1239
pub(crate) fn parse_sql_expr(
1240
expr: &SQLExpr,
1241
ctx: &mut SQLContext,
1242
active_schema: Option<&Schema>,
1243
) -> PolarsResult<Expr> {
1244
let mut visitor = SQLExprVisitor { ctx, active_schema };
1245
visitor.visit_expr(expr)
1246
}
1247
1248
pub(crate) fn parse_sql_array(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Series> {
1249
match expr {
1250
SQLExpr::Array(arr) => {
1251
let mut visitor = SQLExprVisitor {
1252
ctx,
1253
active_schema: None,
1254
};
1255
visitor.array_expr_to_series(arr.elem.as_slice())
1256
},
1257
_ => polars_bail!(SQLSyntax: "Expected array expression, found {:?}", expr),
1258
}
1259
}
1260
1261
pub(crate) fn parse_extract_date_part(expr: Expr, field: &DateTimeField) -> PolarsResult<Expr> {
1262
let field = match field {
1263
// handle 'DATE_PART' and all valid abbreviations/alternates
1264
DateTimeField::Custom(Ident { value, .. }) => {
1265
let value = value.to_ascii_lowercase();
1266
match value.as_str() {
1267
"millennium" | "millennia" => &DateTimeField::Millennium,
1268
"century" | "centuries" => &DateTimeField::Century,
1269
"decade" | "decades" => &DateTimeField::Decade,
1270
"isoyear" => &DateTimeField::Isoyear,
1271
"year" | "years" | "y" => &DateTimeField::Year,
1272
"quarter" | "quarters" => &DateTimeField::Quarter,
1273
"month" | "months" | "mon" | "mons" => &DateTimeField::Month,
1274
"dayofyear" | "doy" => &DateTimeField::DayOfYear,
1275
"dayofweek" | "dow" => &DateTimeField::DayOfWeek,
1276
"isoweek" | "week" | "weeks" => &DateTimeField::IsoWeek,
1277
"isodow" => &DateTimeField::Isodow,
1278
"day" | "days" | "d" => &DateTimeField::Day,
1279
"hour" | "hours" | "h" => &DateTimeField::Hour,
1280
"minute" | "minutes" | "mins" | "min" | "m" => &DateTimeField::Minute,
1281
"second" | "seconds" | "sec" | "secs" | "s" => &DateTimeField::Second,
1282
"millisecond" | "milliseconds" | "ms" => &DateTimeField::Millisecond,
1283
"microsecond" | "microseconds" | "us" => &DateTimeField::Microsecond,
1284
"nanosecond" | "nanoseconds" | "ns" => &DateTimeField::Nanosecond,
1285
#[cfg(feature = "timezones")]
1286
"timezone" => &DateTimeField::Timezone,
1287
"time" => &DateTimeField::Time,
1288
"epoch" => &DateTimeField::Epoch,
1289
_ => {
1290
polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", value)
1291
},
1292
}
1293
},
1294
_ => field,
1295
};
1296
Ok(match field {
1297
DateTimeField::Millennium => expr.dt().millennium(),
1298
DateTimeField::Century => expr.dt().century(),
1299
DateTimeField::Decade => expr.dt().year() / typed_lit(10i32),
1300
DateTimeField::Isoyear => expr.dt().iso_year(),
1301
DateTimeField::Year | DateTimeField::Years => expr.dt().year(),
1302
DateTimeField::Quarter => expr.dt().quarter(),
1303
DateTimeField::Month | DateTimeField::Months => expr.dt().month(),
1304
DateTimeField::Week(weekday) => {
1305
if weekday.is_some() {
1306
polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
1307
}
1308
expr.dt().week()
1309
},
1310
DateTimeField::IsoWeek | DateTimeField::Weeks => expr.dt().week(),
1311
DateTimeField::DayOfYear | DateTimeField::Doy => expr.dt().ordinal_day(),
1312
DateTimeField::DayOfWeek | DateTimeField::Dow => {
1313
let w = expr.dt().weekday();
1314
when(w.clone().eq(typed_lit(7i8)))
1315
.then(typed_lit(0i8))
1316
.otherwise(w)
1317
},
1318
DateTimeField::Isodow => expr.dt().weekday(),
1319
DateTimeField::Day | DateTimeField::Days => expr.dt().day(),
1320
DateTimeField::Hour | DateTimeField::Hours => expr.dt().hour(),
1321
DateTimeField::Minute | DateTimeField::Minutes => expr.dt().minute(),
1322
DateTimeField::Second | DateTimeField::Seconds => expr.dt().second(),
1323
DateTimeField::Millisecond | DateTimeField::Milliseconds => {
1324
(expr.clone().dt().second() * typed_lit(1_000f64))
1325
+ expr.dt().nanosecond().div(typed_lit(1_000_000f64))
1326
},
1327
DateTimeField::Microsecond | DateTimeField::Microseconds => {
1328
(expr.clone().dt().second() * typed_lit(1_000_000f64))
1329
+ expr.dt().nanosecond().div(typed_lit(1_000f64))
1330
},
1331
DateTimeField::Nanosecond | DateTimeField::Nanoseconds => {
1332
(expr.clone().dt().second() * typed_lit(1_000_000_000f64)) + expr.dt().nanosecond()
1333
},
1334
DateTimeField::Time => expr.dt().time(),
1335
#[cfg(feature = "timezones")]
1336
DateTimeField::Timezone => expr.dt().base_utc_offset().dt().total_seconds(false),
1337
DateTimeField::Epoch => {
1338
expr.clone()
1339
.dt()
1340
.timestamp(TimeUnit::Nanoseconds)
1341
.div(typed_lit(1_000_000_000i64))
1342
+ expr.dt().nanosecond().div(typed_lit(1_000_000_000f64))
1343
},
1344
_ => {
1345
polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
1346
},
1347
})
1348
}
1349
1350
/// Allow an expression that represents a 1-indexed parameter to
1351
/// be adjusted from 1-indexed (SQL) to 0-indexed (Rust/Polars)
1352
pub(crate) fn adjust_one_indexed_param(idx: Expr, null_if_zero: bool) -> Expr {
1353
match idx {
1354
Expr::Literal(sc) if sc.is_null() => lit(LiteralValue::untyped_null()),
1355
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => {
1356
if null_if_zero {
1357
lit(LiteralValue::untyped_null())
1358
} else {
1359
idx
1360
}
1361
},
1362
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n < 0 => idx,
1363
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => lit(n - 1),
1364
// TODO: when 'saturating_sub' is available, should be able
1365
// to streamline the when/then/otherwise block below -
1366
_ => when(idx.clone().gt(lit(0)))
1367
.then(idx.clone() - lit(1))
1368
.otherwise(if null_if_zero {
1369
when(idx.clone().eq(lit(0)))
1370
.then(lit(LiteralValue::untyped_null()))
1371
.otherwise(idx.clone())
1372
} else {
1373
idx.clone()
1374
}),
1375
}
1376
}
1377
1378
fn resolve_column<'a>(
1379
ctx: &'a mut SQLContext,
1380
ident_root: &'a Ident,
1381
name: &'a str,
1382
dtype: &'a DataType,
1383
) -> PolarsResult<(Expr, Option<&'a DataType>)> {
1384
let resolved = ctx.resolve_name(&ident_root.value, name);
1385
let resolved = resolved.as_str();
1386
Ok((
1387
if name != resolved {
1388
col(resolved).alias(name)
1389
} else {
1390
col(name)
1391
},
1392
Some(dtype),
1393
))
1394
}
1395
1396
pub(crate) fn resolve_compound_identifier(
1397
ctx: &mut SQLContext,
1398
idents: &[Ident],
1399
active_schema: Option<&Schema>,
1400
) -> PolarsResult<Vec<Expr>> {
1401
// inference priority: table > struct > column
1402
let ident_root = &idents[0];
1403
let mut remaining_idents = idents.iter().skip(1);
1404
let mut lf = ctx.get_table_from_current_scope(&ident_root.value);
1405
1406
// get schema from table (or the active/default schema)
1407
let schema = if let Some(ref mut lf) = lf {
1408
lf.schema_with_arenas(&mut ctx.lp_arena, &mut ctx.expr_arena)?
1409
} else {
1410
Arc::new(active_schema.cloned().unwrap_or_default())
1411
};
1412
1413
// handle simple/unqualified column reference with no schema
1414
if lf.is_none() && schema.is_empty() {
1415
let (mut column, mut dtype): (Expr, Option<&DataType>) =
1416
(col(ident_root.value.as_str()), None);
1417
1418
// traverse the remaining struct field path (if any)
1419
for ident in remaining_idents {
1420
let name = ident.value.as_str();
1421
match dtype {
1422
Some(DataType::Struct(fields)) if name == "*" => {
1423
return Ok(fields
1424
.iter()
1425
.map(|fld| column.clone().struct_().field_by_name(&fld.name))
1426
.collect());
1427
},
1428
Some(DataType::Struct(fields)) => {
1429
dtype = fields
1430
.iter()
1431
.find(|fld| fld.name == name)
1432
.map(|fld| &fld.dtype);
1433
},
1434
Some(dtype) if name == "*" => {
1435
polars_bail!(SQLSyntax: "cannot expand '*' on non-Struct dtype; found {:?}", dtype)
1436
},
1437
_ => dtype = None,
1438
}
1439
column = column.struct_().field_by_name(name);
1440
}
1441
return Ok(vec![column]);
1442
}
1443
1444
let name = &remaining_idents.next().unwrap().value;
1445
1446
// handle "table.*" wildcard expansion
1447
if lf.is_some() && name == "*" {
1448
return schema
1449
.iter_names_and_dtypes()
1450
.map(|(name, dtype)| resolve_column(ctx, ident_root, name, dtype).map(|(expr, _)| expr))
1451
.collect();
1452
}
1453
1454
// resolve column/struct reference
1455
let col_dtype: PolarsResult<(Expr, Option<&DataType>)> =
1456
match (lf.is_none(), schema.get(&ident_root.value)) {
1457
// root is a column/struct in schema (no table)
1458
(true, Some(dtype)) => {
1459
remaining_idents = idents.iter().skip(1);
1460
Ok((col(ident_root.value.as_str()), Some(dtype)))
1461
},
1462
// root is not in schema and no table found
1463
(true, None) => {
1464
polars_bail!(
1465
SQLInterface: "no table or struct column named '{}' found",
1466
ident_root
1467
)
1468
},
1469
// root is a table, resolve column from table schema
1470
(false, _) => {
1471
if let Some((_, col_name, dtype)) = schema.get_full(name) {
1472
resolve_column(ctx, ident_root, col_name, dtype)
1473
} else {
1474
polars_bail!(
1475
SQLInterface: "no column named '{}' found in table '{}'",
1476
name, ident_root
1477
)
1478
}
1479
},
1480
};
1481
1482
// additional ident levels index into struct fields (eg: "df.col.field.nested_field")
1483
let (mut column, mut dtype) = col_dtype?;
1484
for ident in remaining_idents {
1485
let name = ident.value.as_str();
1486
match dtype {
1487
Some(DataType::Struct(fields)) if name == "*" => {
1488
return Ok(fields
1489
.iter()
1490
.map(|fld| column.clone().struct_().field_by_name(&fld.name))
1491
.collect());
1492
},
1493
Some(DataType::Struct(fields)) => {
1494
dtype = fields
1495
.iter()
1496
.find(|fld| fld.name == name)
1497
.map(|fld| &fld.dtype);
1498
},
1499
Some(dtype) if name == "*" => {
1500
polars_bail!(SQLSyntax: "cannot expand '*' on non-Struct dtype; found {:?}", dtype)
1501
},
1502
_ => {
1503
dtype = None;
1504
},
1505
}
1506
column = column.struct_().field_by_name(name);
1507
}
1508
Ok(vec![column])
1509
}
1510
1511