Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-sql/src/sql_expr.rs
6939 views
1
//! Expressions that are supported by the Polars SQL interface.
2
//!
3
//! This is useful for syntax highlighting
4
//!
5
//! This module defines:
6
//! - all Polars SQL keywords [`all_keywords`]
7
//! - all Polars SQL functions [`all_functions`]
8
9
use std::fmt::Display;
10
use std::ops::Div;
11
12
use polars_core::prelude::*;
13
use polars_lazy::prelude::*;
14
use polars_plan::plans::DynLiteralValue;
15
use polars_plan::prelude::typed_lit;
16
use polars_time::Duration;
17
use rand::Rng;
18
use rand::distr::Alphanumeric;
19
#[cfg(feature = "serde")]
20
use serde::{Deserialize, Serialize};
21
use sqlparser::ast::{
22
BinaryOperator as SQLBinaryOperator, CastFormat, CastKind, DataType as SQLDataType,
23
DateTimeField, Expr as SQLExpr, Function as SQLFunction, Ident, Interval, Query as Subquery,
24
SelectItem, Subscript, TimezoneInfo, TrimWhereField, UnaryOperator, Value as SQLValue,
25
};
26
use sqlparser::dialect::GenericDialect;
27
use sqlparser::parser::{Parser, ParserOptions};
28
29
use crate::SQLContext;
30
use crate::functions::SQLFunctionVisitor;
31
use crate::types::{
32
bitstring_to_bytes_literal, is_iso_date, is_iso_datetime, is_iso_time, map_sql_dtype_to_polars,
33
};
34
35
#[inline]
36
#[cold]
37
#[must_use]
38
/// Convert a Display-able error to PolarsError::SQLInterface
39
pub fn to_sql_interface_err(err: impl Display) -> PolarsError {
40
PolarsError::SQLInterface(err.to_string().into())
41
}
42
43
/// Categorises the type of (allowed) subquery constraint.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SubqueryRestriction {
    /// The subquery must yield a single column.
    SingleColumn,
    // Placeholders that are not implemented yet:
    //   SingleRow, SingleValue, Any
}
53
54
/// Recursively walks a SQL Expr to create a polars Expr
55
pub(crate) struct SQLExprVisitor<'a> {
56
ctx: &'a mut SQLContext,
57
active_schema: Option<&'a Schema>,
58
}
59
60
impl SQLExprVisitor<'_> {
61
fn array_expr_to_series(&mut self, elements: &[SQLExpr]) -> PolarsResult<Series> {
62
let mut array_elements = Vec::with_capacity(elements.len());
63
for e in elements {
64
let val = match e {
65
SQLExpr::Value(v) => self.visit_any_value(v, None),
66
SQLExpr::UnaryOp { op, expr } => match expr.as_ref() {
67
SQLExpr::Value(v) => self.visit_any_value(v, Some(op)),
68
_ => Err(polars_err!(SQLInterface: "array element {:?} is not supported", e)),
69
},
70
SQLExpr::Array(values) => {
71
let srs = self.array_expr_to_series(&values.elem)?;
72
Ok(AnyValue::List(srs))
73
},
74
_ => Err(polars_err!(SQLInterface: "array element {:?} is not supported", e)),
75
}?
76
.into_static();
77
array_elements.push(val);
78
}
79
Series::from_any_values(PlSmallStr::EMPTY, &array_elements, true)
80
}
81
82
fn visit_expr(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
83
match expr {
84
SQLExpr::AllOp {
85
left,
86
compare_op,
87
right,
88
} => self.visit_all(left, compare_op, right),
89
SQLExpr::AnyOp {
90
left,
91
compare_op,
92
right,
93
is_some: _,
94
} => self.visit_any(left, compare_op, right),
95
SQLExpr::Array(arr) => self.visit_array_expr(&arr.elem, true, None),
96
SQLExpr::Between {
97
expr,
98
negated,
99
low,
100
high,
101
} => self.visit_between(expr, *negated, low, high),
102
SQLExpr::BinaryOp { left, op, right } => self.visit_binary_op(left, op, right),
103
SQLExpr::Cast {
104
kind,
105
expr,
106
data_type,
107
format,
108
} => self.visit_cast(expr, data_type, format, kind),
109
SQLExpr::Ceil { expr, .. } => Ok(self.visit_expr(expr)?.ceil()),
110
SQLExpr::CompoundIdentifier(idents) => self.visit_compound_identifier(idents),
111
SQLExpr::Extract {
112
field,
113
syntax: _,
114
expr,
115
} => parse_extract_date_part(self.visit_expr(expr)?, field),
116
SQLExpr::Floor { expr, .. } => Ok(self.visit_expr(expr)?.floor()),
117
SQLExpr::Function(function) => self.visit_function(function),
118
SQLExpr::Identifier(ident) => self.visit_identifier(ident),
119
SQLExpr::InList {
120
expr,
121
list,
122
negated,
123
} => {
124
let expr = self.visit_expr(expr)?;
125
let elems = self.visit_array_expr(list, true, Some(&expr))?;
126
let is_in = expr.is_in(elems, false);
127
Ok(if *negated { is_in.not() } else { is_in })
128
},
129
SQLExpr::InSubquery {
130
expr,
131
subquery,
132
negated,
133
} => self.visit_in_subquery(expr, subquery, *negated),
134
SQLExpr::Interval(interval) => Ok(lit(interval_to_duration(interval, true)?)),
135
SQLExpr::IsDistinctFrom(e1, e2) => {
136
Ok(self.visit_expr(e1)?.neq_missing(self.visit_expr(e2)?))
137
},
138
SQLExpr::IsFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false))),
139
SQLExpr::IsNotDistinctFrom(e1, e2) => {
140
Ok(self.visit_expr(e1)?.eq_missing(self.visit_expr(e2)?))
141
},
142
SQLExpr::IsNotFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false)).not()),
143
SQLExpr::IsNotNull(expr) => Ok(self.visit_expr(expr)?.is_not_null()),
144
SQLExpr::IsNotTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true)).not()),
145
SQLExpr::IsNull(expr) => Ok(self.visit_expr(expr)?.is_null()),
146
SQLExpr::IsTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true))),
147
SQLExpr::Like {
148
negated,
149
any,
150
expr,
151
pattern,
152
escape_char,
153
} => {
154
if *any {
155
polars_bail!(SQLSyntax: "LIKE ANY is not a supported syntax")
156
}
157
self.visit_like(*negated, expr, pattern, escape_char, false)
158
},
159
SQLExpr::ILike {
160
negated,
161
any,
162
expr,
163
pattern,
164
escape_char,
165
} => {
166
if *any {
167
polars_bail!(SQLSyntax: "ILIKE ANY is not a supported syntax")
168
}
169
self.visit_like(*negated, expr, pattern, escape_char, true)
170
},
171
SQLExpr::Nested(expr) => self.visit_expr(expr),
172
SQLExpr::Position { expr, r#in } => Ok(
173
// note: SQL is 1-indexed
174
(self
175
.visit_expr(r#in)?
176
.str()
177
.find(self.visit_expr(expr)?, true)
178
+ typed_lit(1u32))
179
.fill_null(typed_lit(0u32)),
180
),
181
SQLExpr::RLike {
182
// note: parses both RLIKE and REGEXP
183
negated,
184
expr,
185
pattern,
186
regexp: _,
187
} => {
188
let matches = self
189
.visit_expr(expr)?
190
.str()
191
.contains(self.visit_expr(pattern)?, true);
192
Ok(if *negated { matches.not() } else { matches })
193
},
194
SQLExpr::Subscript { expr, subscript } => self.visit_subscript(expr, subscript),
195
SQLExpr::Subquery(_) => polars_bail!(SQLInterface: "unexpected subquery"),
196
SQLExpr::Trim {
197
expr,
198
trim_where,
199
trim_what,
200
trim_characters,
201
} => self.visit_trim(expr, trim_where, trim_what, trim_characters),
202
SQLExpr::TypedString { data_type, value } => match data_type {
203
SQLDataType::Date => {
204
if is_iso_date(value) {
205
Ok(lit(value.as_str()).cast(DataType::Date))
206
} else {
207
polars_bail!(SQLSyntax: "invalid DATE literal '{}'", value)
208
}
209
},
210
SQLDataType::Time(None, TimezoneInfo::None) => {
211
if is_iso_time(value) {
212
Ok(lit(value.as_str()).str().to_time(StrptimeOptions {
213
strict: true,
214
..Default::default()
215
}))
216
} else {
217
polars_bail!(SQLSyntax: "invalid TIME literal '{}'", value)
218
}
219
},
220
SQLDataType::Timestamp(None, TimezoneInfo::None) | SQLDataType::Datetime(None) => {
221
if is_iso_datetime(value) {
222
Ok(lit(value.as_str()).str().to_datetime(
223
None,
224
None,
225
StrptimeOptions {
226
strict: true,
227
..Default::default()
228
},
229
lit("latest"),
230
))
231
} else {
232
let fn_name = match data_type {
233
SQLDataType::Timestamp(_, _) => "TIMESTAMP",
234
SQLDataType::Datetime(_) => "DATETIME",
235
_ => unreachable!(),
236
};
237
polars_bail!(SQLSyntax: "invalid {} literal '{}'", fn_name, value)
238
}
239
},
240
_ => {
241
polars_bail!(SQLInterface: "typed literal should be one of DATE, DATETIME, TIME, or TIMESTAMP (found {})", data_type)
242
},
243
},
244
SQLExpr::UnaryOp { op, expr } => self.visit_unary_op(op, expr),
245
SQLExpr::Value(value) => self.visit_literal(value),
246
SQLExpr::Wildcard(_) => Ok(all().as_expr()),
247
e @ SQLExpr::Case { .. } => self.visit_case_when_then(e),
248
other => {
249
polars_bail!(SQLInterface: "expression {:?} is not currently supported", other)
250
},
251
}
252
}
253
254
fn visit_subquery(
255
&mut self,
256
subquery: &Subquery,
257
restriction: SubqueryRestriction,
258
) -> PolarsResult<Expr> {
259
if subquery.with.is_some() {
260
polars_bail!(SQLSyntax: "SQL subquery cannot be a CTE 'WITH' clause");
261
}
262
let mut lf = self.ctx.execute_query_no_ctes(subquery)?;
263
let schema = self.ctx.get_frame_schema(&mut lf)?;
264
265
if restriction == SubqueryRestriction::SingleColumn {
266
if schema.len() != 1 {
267
polars_bail!(SQLSyntax: "SQL subquery returns more than one column");
268
}
269
let rand_string: String = rand::rng()
270
.sample_iter(&Alphanumeric)
271
.take(16)
272
.map(char::from)
273
.collect();
274
275
let schema_entry = schema.get_at_index(0);
276
if let Some((old_name, _)) = schema_entry {
277
let new_name = String::from(old_name.as_str()) + rand_string.as_str();
278
lf = lf.rename([old_name.to_string()], [new_name.clone()], true);
279
return Ok(Expr::SubPlan(
280
SpecialEq::new(Arc::new(lf.logical_plan)),
281
vec![new_name],
282
));
283
}
284
};
285
polars_bail!(SQLInterface: "subquery type not supported");
286
}
287
288
/// Visit a single SQL identifier.
289
///
290
/// e.g. column
291
fn visit_identifier(&self, ident: &Ident) -> PolarsResult<Expr> {
292
Ok(col(ident.value.as_str()))
293
}
294
295
/// Visit a compound SQL identifier
296
///
297
/// e.g. tbl.column, struct.field, tbl.struct.field (inc. nested struct fields)
298
fn visit_compound_identifier(&mut self, idents: &[Ident]) -> PolarsResult<Expr> {
299
Ok(resolve_compound_identifier(self.ctx, idents, self.active_schema)?[0].clone())
300
}
301
302
fn visit_like(
303
&mut self,
304
negated: bool,
305
expr: &SQLExpr,
306
pattern: &SQLExpr,
307
escape_char: &Option<String>,
308
case_insensitive: bool,
309
) -> PolarsResult<Expr> {
310
if escape_char.is_some() {
311
polars_bail!(SQLInterface: "ESCAPE char for LIKE/ILIKE is not currently supported; found '{}'", escape_char.clone().unwrap());
312
}
313
let pat = match self.visit_expr(pattern) {
314
Ok(Expr::Literal(lv)) if lv.extract_str().is_some() => {
315
PlSmallStr::from_str(lv.extract_str().unwrap())
316
},
317
_ => {
318
polars_bail!(SQLSyntax: "LIKE/ILIKE pattern must be a string literal; found {}", pattern)
319
},
320
};
321
if pat.is_empty() || (!case_insensitive && pat.chars().all(|c| !matches!(c, '%' | '_'))) {
322
// empty string or other exact literal match (eg: no wildcard chars)
323
let op = if negated {
324
SQLBinaryOperator::NotEq
325
} else {
326
SQLBinaryOperator::Eq
327
};
328
self.visit_binary_op(expr, &op, pattern)
329
} else {
330
// create regex from pattern containing SQL wildcard chars ('%' => '.*', '_' => '.')
331
let mut rx = regex::escape(pat.as_str())
332
.replace('%', ".*")
333
.replace('_', ".");
334
335
rx = format!(
336
"^{}{}$",
337
if case_insensitive { "(?is)" } else { "(?s)" },
338
rx
339
);
340
341
let expr = self.visit_expr(expr)?;
342
let matches = expr.str().contains(lit(rx), true);
343
Ok(if negated { matches.not() } else { matches })
344
}
345
}
346
347
fn visit_subscript(&mut self, expr: &SQLExpr, subscript: &Subscript) -> PolarsResult<Expr> {
348
let expr = self.visit_expr(expr)?;
349
Ok(match subscript {
350
Subscript::Index { index } => {
351
let idx = adjust_one_indexed_param(self.visit_expr(index)?, true);
352
expr.list().get(idx, true)
353
},
354
Subscript::Slice { .. } => {
355
polars_bail!(SQLSyntax: "array slice syntax is not currently supported")
356
},
357
})
358
}
359
360
/// Handle implicit temporal string comparisons.
361
///
362
/// eg: clauses such as -
363
/// "dt >= '2024-04-30'"
364
/// "dt = '2077-10-10'::date"
365
/// "dtm::date = '2077-10-10'
366
fn convert_temporal_strings(&mut self, left: &Expr, right: &Expr) -> Expr {
367
if let (Some(name), Some(s), expr_dtype) = match (left, right) {
368
// identify "col <op> string" expressions
369
(Expr::Column(name), Expr::Literal(lv)) if lv.extract_str().is_some() => {
370
(Some(name.clone()), Some(lv.extract_str().unwrap()), None)
371
},
372
// identify "CAST(expr AS type) <op> string" and/or "expr::type <op> string" expressions
373
(Expr::Cast { expr, dtype, .. }, Expr::Literal(lv)) if lv.extract_str().is_some() => {
374
let s = lv.extract_str().unwrap();
375
match &**expr {
376
Expr::Column(name) => (Some(name.clone()), Some(s), Some(dtype)),
377
_ => (None, Some(s), Some(dtype)),
378
}
379
},
380
_ => (None, None, None),
381
} {
382
if expr_dtype.is_none() && self.active_schema.is_none() {
383
right.clone()
384
} else {
385
let left_dtype = expr_dtype.map_or_else(
386
|| {
387
self.active_schema
388
.as_ref()
389
.and_then(|schema| schema.get(&name))
390
},
391
|dt| dt.as_literal(),
392
);
393
match left_dtype {
394
Some(DataType::Time) if is_iso_time(s) => {
395
right.clone().str().to_time(StrptimeOptions {
396
strict: true,
397
..Default::default()
398
})
399
},
400
Some(DataType::Date) if is_iso_date(s) => {
401
right.clone().str().to_date(StrptimeOptions {
402
strict: true,
403
..Default::default()
404
})
405
},
406
Some(DataType::Datetime(tu, tz)) if is_iso_datetime(s) || is_iso_date(s) => {
407
if s.len() == 10 {
408
// handle upcast from ISO date string (10 chars) to datetime
409
lit(format!("{s}T00:00:00"))
410
} else {
411
lit(s.replacen(' ', "T", 1))
412
}
413
.str()
414
.to_datetime(
415
Some(*tu),
416
tz.clone(),
417
StrptimeOptions {
418
strict: true,
419
..Default::default()
420
},
421
lit("latest"),
422
)
423
},
424
_ => right.clone(),
425
}
426
}
427
} else {
428
right.clone()
429
}
430
}
431
432
fn struct_field_access_expr(
433
&mut self,
434
expr: &Expr,
435
path: &str,
436
infer_index: bool,
437
) -> PolarsResult<Expr> {
438
let path_elems = if path.starts_with('{') && path.ends_with('}') {
439
path.trim_matches(|c| c == '{' || c == '}')
440
} else {
441
path
442
}
443
.split(',');
444
445
let mut expr = expr.clone();
446
for p in path_elems {
447
let p = p.trim();
448
expr = if infer_index {
449
match p.parse::<i64>() {
450
Ok(idx) => expr.list().get(lit(idx), true),
451
Err(_) => expr.struct_().field_by_name(p),
452
}
453
} else {
454
expr.struct_().field_by_name(p)
455
}
456
}
457
Ok(expr)
458
}
459
460
/// Visit a SQL binary operator.
461
///
462
/// e.g. "column + 1", "column1 <= column2"
463
fn visit_binary_op(
464
&mut self,
465
left: &SQLExpr,
466
op: &SQLBinaryOperator,
467
right: &SQLExpr,
468
) -> PolarsResult<Expr> {
469
// need special handling for interval offsets and comparisons
470
let (lhs, mut rhs) = match (left, op, right) {
471
(_, SQLBinaryOperator::Minus, SQLExpr::Interval(v)) => {
472
let duration = interval_to_duration(v, false)?;
473
return Ok(self
474
.visit_expr(left)?
475
.dt()
476
.offset_by(lit(format!("-{duration}"))));
477
},
478
(_, SQLBinaryOperator::Plus, SQLExpr::Interval(v)) => {
479
let duration = interval_to_duration(v, false)?;
480
return Ok(self
481
.visit_expr(left)?
482
.dt()
483
.offset_by(lit(format!("{duration}"))));
484
},
485
(SQLExpr::Interval(v1), _, SQLExpr::Interval(v2)) => {
486
// shortcut interval comparison evaluation (-> bool)
487
let d1 = interval_to_duration(v1, false)?;
488
let d2 = interval_to_duration(v2, false)?;
489
let res = match op {
490
SQLBinaryOperator::Gt => Ok(lit(d1 > d2)),
491
SQLBinaryOperator::Lt => Ok(lit(d1 < d2)),
492
SQLBinaryOperator::GtEq => Ok(lit(d1 >= d2)),
493
SQLBinaryOperator::LtEq => Ok(lit(d1 <= d2)),
494
SQLBinaryOperator::NotEq => Ok(lit(d1 != d2)),
495
SQLBinaryOperator::Eq | SQLBinaryOperator::Spaceship => Ok(lit(d1 == d2)),
496
_ => polars_bail!(SQLInterface: "invalid interval comparison operator"),
497
};
498
if res.is_ok() {
499
return res;
500
}
501
(self.visit_expr(left)?, self.visit_expr(right)?)
502
},
503
_ => (self.visit_expr(left)?, self.visit_expr(right)?),
504
};
505
rhs = self.convert_temporal_strings(&lhs, &rhs);
506
507
Ok(match op {
508
// ----
509
// Bitwise operators
510
// ----
511
SQLBinaryOperator::BitwiseAnd => lhs.and(rhs), // "x & y"
512
SQLBinaryOperator::BitwiseOr => lhs.or(rhs), // "x | y"
513
SQLBinaryOperator::Xor => lhs.xor(rhs), // "x XOR y"
514
515
// ----
516
// General operators
517
// ----
518
SQLBinaryOperator::And => lhs.and(rhs), // "x AND y"
519
SQLBinaryOperator::Divide => lhs / rhs, // "x / y"
520
SQLBinaryOperator::DuckIntegerDivide => lhs.floor_div(rhs).cast(DataType::Int64), // "x // y"
521
SQLBinaryOperator::Eq => lhs.eq(rhs), // "x = y"
522
SQLBinaryOperator::Gt => lhs.gt(rhs), // "x > y"
523
SQLBinaryOperator::GtEq => lhs.gt_eq(rhs), // "x >= y"
524
SQLBinaryOperator::Lt => lhs.lt(rhs), // "x < y"
525
SQLBinaryOperator::LtEq => lhs.lt_eq(rhs), // "x <= y"
526
SQLBinaryOperator::Minus => lhs - rhs, // "x - y"
527
SQLBinaryOperator::Modulo => lhs % rhs, // "x % y"
528
SQLBinaryOperator::Multiply => lhs * rhs, // "x * y"
529
SQLBinaryOperator::NotEq => lhs.eq(rhs).not(), // "x != y"
530
SQLBinaryOperator::Or => lhs.or(rhs), // "x OR y"
531
SQLBinaryOperator::Plus => lhs + rhs, // "x + y"
532
SQLBinaryOperator::Spaceship => lhs.eq_missing(rhs), // "x <=> y"
533
SQLBinaryOperator::StringConcat => { // "x || y"
534
lhs.cast(DataType::String) + rhs.cast(DataType::String)
535
},
536
SQLBinaryOperator::PGStartsWith => lhs.str().starts_with(rhs), // "x ^@ y"
537
// ----
538
// Regular expression operators
539
// ----
540
SQLBinaryOperator::PGRegexMatch => match rhs { // "x ~ y"
541
Expr::Literal(ref lv) if lv.extract_str().is_some() => lhs.str().contains(rhs, true),
542
_ => polars_bail!(SQLSyntax: "invalid pattern for '~' operator: {:?}", rhs),
543
},
544
SQLBinaryOperator::PGRegexNotMatch => match rhs { // "x !~ y"
545
Expr::Literal(ref lv) if lv.extract_str().is_some() => lhs.str().contains(rhs, true).not(),
546
_ => polars_bail!(SQLSyntax: "invalid pattern for '!~' operator: {:?}", rhs),
547
},
548
SQLBinaryOperator::PGRegexIMatch => match rhs { // "x ~* y"
549
Expr::Literal(ref lv) if lv.extract_str().is_some() => {
550
let pat = lv.extract_str().unwrap();
551
lhs.str().contains(lit(format!("(?i){pat}")), true)
552
},
553
_ => polars_bail!(SQLSyntax: "invalid pattern for '~*' operator: {:?}", rhs),
554
},
555
SQLBinaryOperator::PGRegexNotIMatch => match rhs { // "x !~* y"
556
Expr::Literal(ref lv) if lv.extract_str().is_some() => {
557
let pat = lv.extract_str().unwrap();
558
lhs.str().contains(lit(format!("(?i){pat}")), true).not()
559
},
560
_ => {
561
polars_bail!(SQLSyntax: "invalid pattern for '!~*' operator: {:?}", rhs)
562
},
563
},
564
// ----
565
// LIKE/ILIKE operators
566
// ----
567
SQLBinaryOperator::PGLikeMatch // "x ~~ y"
568
| SQLBinaryOperator::PGNotLikeMatch // "x !~~ y"
569
| SQLBinaryOperator::PGILikeMatch // "x ~~* y"
570
| SQLBinaryOperator::PGNotILikeMatch => { // "x !~~* y"
571
let expr = if matches!(
572
op,
573
SQLBinaryOperator::PGLikeMatch | SQLBinaryOperator::PGNotLikeMatch
574
) {
575
SQLExpr::Like {
576
negated: matches!(op, SQLBinaryOperator::PGNotLikeMatch),
577
any: false,
578
expr: Box::new(left.clone()),
579
pattern: Box::new(right.clone()),
580
escape_char: None,
581
}
582
} else {
583
SQLExpr::ILike {
584
negated: matches!(op, SQLBinaryOperator::PGNotILikeMatch),
585
any: false,
586
expr: Box::new(left.clone()),
587
pattern: Box::new(right.clone()),
588
escape_char: None,
589
}
590
};
591
self.visit_expr(&expr)?
592
},
593
// ----
594
// JSON/Struct field access operators
595
// ----
596
SQLBinaryOperator::Arrow | SQLBinaryOperator::LongArrow => match rhs { // "x -> y", "x ->> y"
597
Expr::Literal(lv) if lv.extract_str().is_some() => {
598
let path = lv.extract_str().unwrap();
599
let mut expr = self.struct_field_access_expr(&lhs, path, false)?;
600
if let SQLBinaryOperator::LongArrow = op {
601
expr = expr.cast(DataType::String);
602
}
603
expr
604
},
605
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(idx))) => {
606
let mut expr = self.struct_field_access_expr(&lhs, &idx.to_string(), true)?;
607
if let SQLBinaryOperator::LongArrow = op {
608
expr = expr.cast(DataType::String);
609
}
610
expr
611
},
612
_ => {
613
polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", right)
614
},
615
},
616
SQLBinaryOperator::HashArrow | SQLBinaryOperator::HashLongArrow => { // "x #> y", "x #>> y"
617
match rhs {
618
Expr::Literal(lv) if lv.extract_str().is_some() => {
619
let path = lv.extract_str().unwrap();
620
let mut expr = self.struct_field_access_expr(&lhs, path, true)?;
621
if let SQLBinaryOperator::HashLongArrow = op {
622
expr = expr.cast(DataType::String);
623
}
624
expr
625
},
626
_ => {
627
polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", rhs)
628
}
629
}
630
},
631
other => {
632
polars_bail!(SQLInterface: "operator {:?} is not currently supported", other)
633
},
634
})
635
}
636
637
/// Visit a SQL unary operator.
638
///
639
/// e.g. +column or -column
640
fn visit_unary_op(&mut self, op: &UnaryOperator, expr: &SQLExpr) -> PolarsResult<Expr> {
641
let expr = self.visit_expr(expr)?;
642
Ok(match (op, expr.clone()) {
643
// simplify the parse tree by special-casing common unary +/- ops
644
(UnaryOperator::Plus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) => {
645
lit(n)
646
},
647
(UnaryOperator::Plus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(n)))) => {
648
lit(n)
649
},
650
(UnaryOperator::Minus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) => {
651
lit(-n)
652
},
653
(UnaryOperator::Minus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(n)))) => {
654
lit(-n)
655
},
656
// general case
657
(UnaryOperator::Plus, _) => lit(0) + expr,
658
(UnaryOperator::Minus, _) => lit(0) - expr,
659
(UnaryOperator::Not, _) => expr.not(),
660
other => polars_bail!(SQLInterface: "unary operator {:?} is not supported", other),
661
})
662
}
663
664
/// Visit a SQL function.
665
///
666
/// e.g. SUM(column) or COUNT(*)
667
///
668
/// See [SQLFunctionVisitor] for more details
669
fn visit_function(&mut self, function: &SQLFunction) -> PolarsResult<Expr> {
670
let mut visitor = SQLFunctionVisitor {
671
func: function,
672
ctx: self.ctx,
673
active_schema: self.active_schema,
674
};
675
visitor.visit_function()
676
}
677
678
/// Visit a SQL `ALL` expression.
679
///
680
/// e.g. `a > ALL(y)`
681
fn visit_all(
682
&mut self,
683
left: &SQLExpr,
684
compare_op: &SQLBinaryOperator,
685
right: &SQLExpr,
686
) -> PolarsResult<Expr> {
687
let left = self.visit_expr(left)?;
688
let right = self.visit_expr(right)?;
689
690
match compare_op {
691
SQLBinaryOperator::Gt => Ok(left.gt(right.max())),
692
SQLBinaryOperator::Lt => Ok(left.lt(right.min())),
693
SQLBinaryOperator::GtEq => Ok(left.gt_eq(right.max())),
694
SQLBinaryOperator::LtEq => Ok(left.lt_eq(right.min())),
695
SQLBinaryOperator::Eq => polars_bail!(SQLSyntax: "ALL cannot be used with ="),
696
SQLBinaryOperator::NotEq => polars_bail!(SQLSyntax: "ALL cannot be used with !="),
697
_ => polars_bail!(SQLInterface: "invalid comparison operator"),
698
}
699
}
700
701
/// Visit a SQL `ANY` expression.
702
///
703
/// e.g. `a != ANY(y)`
704
fn visit_any(
705
&mut self,
706
left: &SQLExpr,
707
compare_op: &SQLBinaryOperator,
708
right: &SQLExpr,
709
) -> PolarsResult<Expr> {
710
let left = self.visit_expr(left)?;
711
let right = self.visit_expr(right)?;
712
713
match compare_op {
714
SQLBinaryOperator::Gt => Ok(left.gt(right.min())),
715
SQLBinaryOperator::Lt => Ok(left.lt(right.max())),
716
SQLBinaryOperator::GtEq => Ok(left.gt_eq(right.min())),
717
SQLBinaryOperator::LtEq => Ok(left.lt_eq(right.max())),
718
SQLBinaryOperator::Eq => Ok(left.is_in(right, false)),
719
SQLBinaryOperator::NotEq => Ok(left.is_in(right, false).not()),
720
_ => polars_bail!(SQLInterface: "invalid comparison operator"),
721
}
722
}
723
724
/// Visit a SQL `ARRAY` list (including `IN` values).
725
fn visit_array_expr(
726
&mut self,
727
elements: &[SQLExpr],
728
result_as_element: bool,
729
dtype_expr_match: Option<&Expr>,
730
) -> PolarsResult<Expr> {
731
let mut elems = self.array_expr_to_series(elements)?;
732
733
// handle implicit temporal strings, eg: "dt IN ('2024-04-30','2024-05-01')".
734
// (not yet as versatile as the temporal string conversions in visit_binary_op)
735
if let (Some(Expr::Column(name)), Some(schema)) =
736
(dtype_expr_match, self.active_schema.as_ref())
737
{
738
if elems.dtype() == &DataType::String {
739
if let Some(dtype) = schema.get(name) {
740
if matches!(
741
dtype,
742
DataType::Date | DataType::Time | DataType::Datetime(_, _)
743
) {
744
elems = elems.strict_cast(dtype)?;
745
}
746
}
747
}
748
}
749
750
// if we are parsing the list as an element in a series, implode.
751
// otherwise, return the series as-is.
752
let res = if result_as_element {
753
elems.implode()?.into_series()
754
} else {
755
elems
756
};
757
Ok(lit(res))
758
}
759
760
/// Visit a SQL `CAST` or `TRY_CAST` expression.
761
///
762
/// e.g. `CAST(col AS INT)`, `col::int4`, or `TRY_CAST(col AS VARCHAR)`,
763
fn visit_cast(
764
&mut self,
765
expr: &SQLExpr,
766
dtype: &SQLDataType,
767
format: &Option<CastFormat>,
768
cast_kind: &CastKind,
769
) -> PolarsResult<Expr> {
770
if format.is_some() {
771
return Err(
772
polars_err!(SQLInterface: "use of FORMAT is not currently supported in CAST"),
773
);
774
}
775
let expr = self.visit_expr(expr)?;
776
777
#[cfg(feature = "json")]
778
if dtype == &SQLDataType::JSON {
779
// @BROKEN: we cannot handle this.
780
return Ok(expr.str().json_decode(DataType::Struct(Vec::new())));
781
}
782
let polars_type = map_sql_dtype_to_polars(dtype)?;
783
Ok(match cast_kind {
784
CastKind::Cast | CastKind::DoubleColon => expr.strict_cast(polars_type),
785
CastKind::TryCast | CastKind::SafeCast => expr.cast(polars_type),
786
})
787
}
788
789
/// Visit a SQL literal.
790
///
791
/// e.g. 1, 'foo', 1.0, NULL
792
///
793
/// See [SQLValue] and [LiteralValue] for more details
794
fn visit_literal(&self, value: &SQLValue) -> PolarsResult<Expr> {
795
// note: double-quoted strings will be parsed as identifiers, not literals
796
Ok(match value {
797
SQLValue::Boolean(b) => lit(*b),
798
SQLValue::DollarQuotedString(s) => lit(s.value.clone()),
799
#[cfg(feature = "binary_encoding")]
800
SQLValue::HexStringLiteral(x) => {
801
if x.len() % 2 != 0 {
802
polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x)
803
};
804
lit(hex::decode(x.clone()).unwrap())
805
},
806
SQLValue::Null => Expr::Literal(LiteralValue::untyped_null()),
807
SQLValue::Number(s, _) => {
808
// Check for existence of decimal separator dot
809
if s.contains('.') {
810
s.parse::<f64>().map(lit).map_err(|_| ())
811
} else {
812
s.parse::<i64>().map(lit).map_err(|_| ())
813
}
814
.map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))?
815
},
816
SQLValue::SingleQuotedByteStringLiteral(b) => {
817
// note: for PostgreSQL this represents a BIT string literal (eg: b'10101') not a BYTE string
818
// literal (see https://www.postgresql.org/docs/current/datatype-bit.html), but sqlparser-rs
819
// patterned the token name after BigQuery (where b'str' really IS a byte string)
820
bitstring_to_bytes_literal(b)?
821
},
822
SQLValue::SingleQuotedString(s) => lit(s.clone()),
823
other => {
824
polars_bail!(SQLInterface: "value {:?} is not a supported literal type", other)
825
},
826
})
827
}
828
829
/// Visit a SQL literal (like [visit_literal]), but return AnyValue instead of Expr.
830
fn visit_any_value(
831
&self,
832
value: &SQLValue,
833
op: Option<&UnaryOperator>,
834
) -> PolarsResult<AnyValue<'_>> {
835
Ok(match value {
836
SQLValue::Boolean(b) => AnyValue::Boolean(*b),
837
SQLValue::DollarQuotedString(s) => AnyValue::StringOwned(s.clone().value.into()),
838
#[cfg(feature = "binary_encoding")]
839
SQLValue::HexStringLiteral(x) => {
840
if x.len() % 2 != 0 {
841
polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x)
842
};
843
AnyValue::BinaryOwned(hex::decode(x.clone()).unwrap())
844
},
845
SQLValue::Null => AnyValue::Null,
846
SQLValue::Number(s, _) => {
847
let negate = match op {
848
Some(UnaryOperator::Minus) => true,
849
// no op should be taken as plus.
850
Some(UnaryOperator::Plus) | None => false,
851
Some(op) => {
852
polars_bail!(SQLInterface: "unary op {:?} not supported for numeric SQL value", op)
853
},
854
};
855
// Check for existence of decimal separator dot
856
if s.contains('.') {
857
s.parse::<f64>()
858
.map(|n: f64| AnyValue::Float64(if negate { -n } else { n }))
859
.map_err(|_| ())
860
} else {
861
s.parse::<i64>()
862
.map(|n: i64| AnyValue::Int64(if negate { -n } else { n }))
863
.map_err(|_| ())
864
}
865
.map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))?
866
},
867
SQLValue::SingleQuotedByteStringLiteral(b) => {
868
// note: for PostgreSQL this represents a BIT literal (eg: b'10101') not BYTE
869
let bytes_literal = bitstring_to_bytes_literal(b)?;
870
match bytes_literal {
871
Expr::Literal(lv) if lv.extract_binary().is_some() => {
872
AnyValue::BinaryOwned(lv.extract_binary().unwrap().to_vec())
873
},
874
_ => {
875
polars_bail!(SQLInterface: "failed to parse bitstring literal: {:?}", b)
876
},
877
}
878
},
879
SQLValue::SingleQuotedString(s) => AnyValue::StringOwned(s.as_str().into()),
880
other => polars_bail!(SQLInterface: "value {:?} is not currently supported", other),
881
})
882
}
883
884
/// Visit a SQL `BETWEEN` expression.
885
/// See [sqlparser::ast::Expr::Between] for more details
886
fn visit_between(
887
&mut self,
888
expr: &SQLExpr,
889
negated: bool,
890
low: &SQLExpr,
891
high: &SQLExpr,
892
) -> PolarsResult<Expr> {
893
let expr = self.visit_expr(expr)?;
894
let low = self.visit_expr(low)?;
895
let high = self.visit_expr(high)?;
896
897
let low = self.convert_temporal_strings(&expr, &low);
898
let high = self.convert_temporal_strings(&expr, &high);
899
Ok(if negated {
900
expr.clone().lt(low).or(expr.gt(high))
901
} else {
902
expr.clone().gt_eq(low).and(expr.lt_eq(high))
903
})
904
}
905
906
/// Visit a SQL `TRIM` function.
907
/// See [sqlparser::ast::Expr::Trim] for more details
908
fn visit_trim(
909
&mut self,
910
expr: &SQLExpr,
911
trim_where: &Option<TrimWhereField>,
912
trim_what: &Option<Box<SQLExpr>>,
913
trim_characters: &Option<Vec<SQLExpr>>,
914
) -> PolarsResult<Expr> {
915
if trim_characters.is_some() {
916
// TODO: allow compact snowflake/bigquery syntax?
917
return Err(polars_err!(SQLSyntax: "unsupported TRIM syntax (custom chars)"));
918
};
919
let expr = self.visit_expr(expr)?;
920
let trim_what = trim_what.as_ref().map(|e| self.visit_expr(e)).transpose()?;
921
let trim_what = match trim_what {
922
Some(Expr::Literal(lv)) if lv.extract_str().is_some() => {
923
Some(PlSmallStr::from_str(lv.extract_str().unwrap()))
924
},
925
None => None,
926
_ => return self.err(&expr),
927
};
928
Ok(match (trim_where, trim_what) {
929
(None | Some(TrimWhereField::Both), None) => {
930
expr.str().strip_chars(lit(LiteralValue::untyped_null()))
931
},
932
(None | Some(TrimWhereField::Both), Some(val)) => expr.str().strip_chars(lit(val)),
933
(Some(TrimWhereField::Leading), None) => expr
934
.str()
935
.strip_chars_start(lit(LiteralValue::untyped_null())),
936
(Some(TrimWhereField::Leading), Some(val)) => expr.str().strip_chars_start(lit(val)),
937
(Some(TrimWhereField::Trailing), None) => expr
938
.str()
939
.strip_chars_end(lit(LiteralValue::untyped_null())),
940
(Some(TrimWhereField::Trailing), Some(val)) => expr.str().strip_chars_end(lit(val)),
941
})
942
}
943
944
/// Visit a SQL subquery inside and `IN` expression.
945
fn visit_in_subquery(
946
&mut self,
947
expr: &SQLExpr,
948
subquery: &Subquery,
949
negated: bool,
950
) -> PolarsResult<Expr> {
951
let subquery_result = self.visit_subquery(subquery, SubqueryRestriction::SingleColumn)?;
952
let expr = self.visit_expr(expr)?;
953
Ok(if negated {
954
expr.is_in(subquery_result, false).not()
955
} else {
956
expr.is_in(subquery_result, false)
957
})
958
}
959
960
/// Visit `CASE` control flow expression.
961
fn visit_case_when_then(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
962
if let SQLExpr::Case {
963
operand,
964
conditions,
965
results,
966
else_result,
967
} = expr
968
{
969
polars_ensure!(
970
conditions.len() == results.len(),
971
SQLSyntax: "WHEN and THEN expressions must have the same length"
972
);
973
polars_ensure!(
974
!conditions.is_empty(),
975
SQLSyntax: "WHEN and THEN expressions must have at least one element"
976
);
977
978
let mut when_thens = conditions.iter().zip(results.iter());
979
let first = when_thens.next();
980
if first.is_none() {
981
polars_bail!(SQLSyntax: "WHEN and THEN expressions must have at least one element");
982
}
983
let else_res = match else_result {
984
Some(else_res) => self.visit_expr(else_res)?,
985
None => lit(LiteralValue::untyped_null()), // ELSE clause is optional; when omitted, it is implicitly NULL
986
};
987
if let Some(operand_expr) = operand {
988
let first_operand_expr = self.visit_expr(operand_expr)?;
989
990
let first = first.unwrap();
991
let first_cond = first_operand_expr.eq(self.visit_expr(first.0)?);
992
let first_then = self.visit_expr(first.1)?;
993
let expr = when(first_cond).then(first_then);
994
let next = when_thens.next();
995
996
let mut when_then = if let Some((cond, res)) = next {
997
let second_operand_expr = self.visit_expr(operand_expr)?;
998
let cond = second_operand_expr.eq(self.visit_expr(cond)?);
999
let res = self.visit_expr(res)?;
1000
expr.when(cond).then(res)
1001
} else {
1002
return Ok(expr.otherwise(else_res));
1003
};
1004
for (cond, res) in when_thens {
1005
let new_operand_expr = self.visit_expr(operand_expr)?;
1006
let cond = new_operand_expr.eq(self.visit_expr(cond)?);
1007
let res = self.visit_expr(res)?;
1008
when_then = when_then.when(cond).then(res);
1009
}
1010
return Ok(when_then.otherwise(else_res));
1011
}
1012
1013
let first = first.unwrap();
1014
let first_cond = self.visit_expr(first.0)?;
1015
let first_then = self.visit_expr(first.1)?;
1016
let expr = when(first_cond).then(first_then);
1017
let next = when_thens.next();
1018
1019
let mut when_then = if let Some((cond, res)) = next {
1020
let cond = self.visit_expr(cond)?;
1021
let res = self.visit_expr(res)?;
1022
expr.when(cond).then(res)
1023
} else {
1024
return Ok(expr.otherwise(else_res));
1025
};
1026
for (cond, res) in when_thens {
1027
let cond = self.visit_expr(cond)?;
1028
let res = self.visit_expr(res)?;
1029
when_then = when_then.when(cond).then(res);
1030
}
1031
Ok(when_then.otherwise(else_res))
1032
} else {
1033
unreachable!()
1034
}
1035
}
1036
1037
fn err(&self, expr: &Expr) -> PolarsResult<Expr> {
1038
polars_bail!(SQLInterface: "expression {:?} is not currently supported", expr);
1039
}
1040
}
1041
1042
/// parse a SQL expression to a polars expression
1043
/// # Example
1044
/// ```rust
1045
/// # use polars_sql::{SQLContext, sql_expr};
1046
/// # use polars_core::prelude::*;
1047
/// # use polars_lazy::prelude::*;
1048
/// # fn main() {
1049
///
1050
/// let mut ctx = SQLContext::new();
1051
/// let df = df! {
1052
/// "a" => [1, 2, 3],
1053
/// }
1054
/// .unwrap();
1055
/// let expr = sql_expr("MAX(a)").unwrap();
1056
/// df.lazy().select(vec![expr]).collect().unwrap();
1057
/// # }
1058
/// ```
1059
pub fn sql_expr<S: AsRef<str>>(s: S) -> PolarsResult<Expr> {
1060
let mut ctx = SQLContext::new();
1061
1062
let mut parser = Parser::new(&GenericDialect);
1063
parser = parser.with_options(ParserOptions {
1064
trailing_commas: true,
1065
..Default::default()
1066
});
1067
1068
let mut ast = parser
1069
.try_with_sql(s.as_ref())
1070
.map_err(to_sql_interface_err)?;
1071
let expr = ast.parse_select_item().map_err(to_sql_interface_err)?;
1072
1073
Ok(match &expr {
1074
SelectItem::ExprWithAlias { expr, alias } => {
1075
let expr = parse_sql_expr(expr, &mut ctx, None)?;
1076
expr.alias(alias.value.as_str())
1077
},
1078
SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, &mut ctx, None)?,
1079
_ => polars_bail!(SQLInterface: "unable to parse '{}' as Expr", s.as_ref()),
1080
})
1081
}
1082
1083
pub(crate) fn interval_to_duration(interval: &Interval, fixed: bool) -> PolarsResult<Duration> {
1084
if interval.last_field.is_some()
1085
|| interval.leading_field.is_some()
1086
|| interval.leading_precision.is_some()
1087
|| interval.fractional_seconds_precision.is_some()
1088
{
1089
polars_bail!(SQLSyntax: "unsupported interval syntax ('{}')", interval)
1090
}
1091
let s = match &*interval.value {
1092
SQLExpr::UnaryOp { .. } => {
1093
polars_bail!(SQLSyntax: "unary ops are not valid on interval strings; found {}", interval.value)
1094
},
1095
SQLExpr::Value(SQLValue::SingleQuotedString(s)) => Some(s),
1096
_ => None,
1097
};
1098
match s {
1099
Some(s) if s.contains('-') => {
1100
polars_bail!(SQLInterface: "minus signs are not yet supported in interval strings; found '{}'", s)
1101
},
1102
Some(s) => {
1103
// years, quarters, and months do not have a fixed duration; these
1104
// interval parts can only be used with respect to a reference point
1105
let duration = Duration::parse_interval(s);
1106
if fixed && duration.months() != 0 {
1107
polars_bail!(SQLSyntax: "fixed-duration interval cannot contain years, quarters, or months; found {}", s)
1108
};
1109
Ok(duration)
1110
},
1111
None => polars_bail!(SQLSyntax: "invalid interval {:?}", interval),
1112
}
1113
}
1114
1115
pub(crate) fn parse_sql_expr(
1116
expr: &SQLExpr,
1117
ctx: &mut SQLContext,
1118
active_schema: Option<&Schema>,
1119
) -> PolarsResult<Expr> {
1120
let mut visitor = SQLExprVisitor { ctx, active_schema };
1121
visitor.visit_expr(expr)
1122
}
1123
1124
pub(crate) fn parse_sql_array(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Series> {
1125
match expr {
1126
SQLExpr::Array(arr) => {
1127
let mut visitor = SQLExprVisitor {
1128
ctx,
1129
active_schema: None,
1130
};
1131
visitor.array_expr_to_series(arr.elem.as_slice())
1132
},
1133
_ => polars_bail!(SQLSyntax: "Expected array expression, found {:?}", expr),
1134
}
1135
}
1136
1137
pub(crate) fn parse_extract_date_part(expr: Expr, field: &DateTimeField) -> PolarsResult<Expr> {
1138
let field = match field {
1139
// handle 'DATE_PART' and all valid abbreviations/alternates
1140
DateTimeField::Custom(Ident { value, .. }) => {
1141
let value = value.to_ascii_lowercase();
1142
match value.as_str() {
1143
"millennium" | "millennia" => &DateTimeField::Millennium,
1144
"century" | "centuries" => &DateTimeField::Century,
1145
"decade" | "decades" => &DateTimeField::Decade,
1146
"isoyear" => &DateTimeField::Isoyear,
1147
"year" | "years" | "y" => &DateTimeField::Year,
1148
"quarter" | "quarters" => &DateTimeField::Quarter,
1149
"month" | "months" | "mon" | "mons" => &DateTimeField::Month,
1150
"dayofyear" | "doy" => &DateTimeField::DayOfYear,
1151
"dayofweek" | "dow" => &DateTimeField::DayOfWeek,
1152
"isoweek" | "week" | "weeks" => &DateTimeField::IsoWeek,
1153
"isodow" => &DateTimeField::Isodow,
1154
"day" | "days" | "d" => &DateTimeField::Day,
1155
"hour" | "hours" | "h" => &DateTimeField::Hour,
1156
"minute" | "minutes" | "mins" | "min" | "m" => &DateTimeField::Minute,
1157
"second" | "seconds" | "sec" | "secs" | "s" => &DateTimeField::Second,
1158
"millisecond" | "milliseconds" | "ms" => &DateTimeField::Millisecond,
1159
"microsecond" | "microseconds" | "us" => &DateTimeField::Microsecond,
1160
"nanosecond" | "nanoseconds" | "ns" => &DateTimeField::Nanosecond,
1161
#[cfg(feature = "timezones")]
1162
"timezone" => &DateTimeField::Timezone,
1163
"time" => &DateTimeField::Time,
1164
"epoch" => &DateTimeField::Epoch,
1165
_ => {
1166
polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", value)
1167
},
1168
}
1169
},
1170
_ => field,
1171
};
1172
Ok(match field {
1173
DateTimeField::Millennium => expr.dt().millennium(),
1174
DateTimeField::Century => expr.dt().century(),
1175
DateTimeField::Decade => expr.dt().year() / typed_lit(10i32),
1176
DateTimeField::Isoyear => expr.dt().iso_year(),
1177
DateTimeField::Year => expr.dt().year(),
1178
DateTimeField::Quarter => expr.dt().quarter(),
1179
DateTimeField::Month => expr.dt().month(),
1180
DateTimeField::Week(weekday) => {
1181
if weekday.is_some() {
1182
polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
1183
}
1184
expr.dt().week()
1185
},
1186
DateTimeField::IsoWeek => expr.dt().week(),
1187
DateTimeField::DayOfYear | DateTimeField::Doy => expr.dt().ordinal_day(),
1188
DateTimeField::DayOfWeek | DateTimeField::Dow => {
1189
let w = expr.dt().weekday();
1190
when(w.clone().eq(typed_lit(7i8)))
1191
.then(typed_lit(0i8))
1192
.otherwise(w)
1193
},
1194
DateTimeField::Isodow => expr.dt().weekday(),
1195
DateTimeField::Day => expr.dt().day(),
1196
DateTimeField::Hour => expr.dt().hour(),
1197
DateTimeField::Minute => expr.dt().minute(),
1198
DateTimeField::Second => expr.dt().second(),
1199
DateTimeField::Millisecond | DateTimeField::Milliseconds => {
1200
(expr.clone().dt().second() * typed_lit(1_000f64))
1201
+ expr.dt().nanosecond().div(typed_lit(1_000_000f64))
1202
},
1203
DateTimeField::Microsecond | DateTimeField::Microseconds => {
1204
(expr.clone().dt().second() * typed_lit(1_000_000f64))
1205
+ expr.dt().nanosecond().div(typed_lit(1_000f64))
1206
},
1207
DateTimeField::Nanosecond | DateTimeField::Nanoseconds => {
1208
(expr.clone().dt().second() * typed_lit(1_000_000_000f64)) + expr.dt().nanosecond()
1209
},
1210
DateTimeField::Time => expr.dt().time(),
1211
#[cfg(feature = "timezones")]
1212
DateTimeField::Timezone => expr.dt().base_utc_offset().dt().total_seconds(),
1213
DateTimeField::Epoch => {
1214
expr.clone()
1215
.dt()
1216
.timestamp(TimeUnit::Nanoseconds)
1217
.div(typed_lit(1_000_000_000i64))
1218
+ expr.dt().nanosecond().div(typed_lit(1_000_000_000f64))
1219
},
1220
_ => {
1221
polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
1222
},
1223
})
1224
}
1225
1226
/// Allow an expression that represents a 1-indexed parameter to
1227
/// be adjusted from 1-indexed (SQL) to 0-indexed (Rust/Polars)
1228
pub(crate) fn adjust_one_indexed_param(idx: Expr, null_if_zero: bool) -> Expr {
1229
match idx {
1230
Expr::Literal(sc) if sc.is_null() => lit(LiteralValue::untyped_null()),
1231
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => {
1232
if null_if_zero {
1233
lit(LiteralValue::untyped_null())
1234
} else {
1235
idx
1236
}
1237
},
1238
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n < 0 => idx,
1239
Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => lit(n - 1),
1240
// TODO: when 'saturating_sub' is available, should be able
1241
// to streamline the when/then/otherwise block below -
1242
_ => when(idx.clone().gt(lit(0)))
1243
.then(idx.clone() - lit(1))
1244
.otherwise(if null_if_zero {
1245
when(idx.clone().eq(lit(0)))
1246
.then(lit(LiteralValue::untyped_null()))
1247
.otherwise(idx.clone())
1248
} else {
1249
idx.clone()
1250
}),
1251
}
1252
}
1253
1254
fn resolve_column<'a>(
1255
ctx: &'a mut SQLContext,
1256
ident_root: &'a Ident,
1257
name: &'a str,
1258
dtype: &'a DataType,
1259
) -> PolarsResult<(Expr, Option<&'a DataType>)> {
1260
let resolved = ctx.resolve_name(&ident_root.value, name);
1261
let resolved = resolved.as_str();
1262
Ok((
1263
if name != resolved {
1264
col(resolved).alias(name)
1265
} else {
1266
col(name)
1267
},
1268
Some(dtype),
1269
))
1270
}
1271
1272
pub(crate) fn resolve_compound_identifier(
1273
ctx: &mut SQLContext,
1274
idents: &[Ident],
1275
active_schema: Option<&Schema>,
1276
) -> PolarsResult<Vec<Expr>> {
1277
// inference priority: table > struct > column
1278
let ident_root = &idents[0];
1279
let mut remaining_idents = idents.iter().skip(1);
1280
let mut lf = ctx.get_table_from_current_scope(&ident_root.value);
1281
1282
let schema = if let Some(ref mut lf) = lf {
1283
lf.schema_with_arenas(&mut ctx.lp_arena, &mut ctx.expr_arena)
1284
} else {
1285
Ok(Arc::new(if let Some(active_schema) = active_schema {
1286
active_schema.clone()
1287
} else {
1288
Schema::default()
1289
}))
1290
}?;
1291
1292
let col_dtype: PolarsResult<(Expr, Option<&DataType>)> = if lf.is_none() && schema.is_empty() {
1293
Ok((col(ident_root.value.as_str()), None))
1294
} else {
1295
let name = &remaining_idents.next().unwrap().value;
1296
if lf.is_some() && name == "*" {
1297
return Ok(schema
1298
.iter_names_and_dtypes()
1299
.map(|(name, dtype)| resolve_column(ctx, ident_root, name, dtype).unwrap().0)
1300
.collect::<Vec<_>>());
1301
};
1302
let root_is_field = schema.get(&ident_root.value).is_some();
1303
if lf.is_none() && root_is_field {
1304
remaining_idents = idents.iter().skip(1);
1305
Ok((
1306
col(ident_root.value.as_str()),
1307
schema.get(&ident_root.value),
1308
))
1309
} else if lf.is_none() && !root_is_field {
1310
polars_bail!(
1311
SQLInterface: "no table or struct column named '{}' found",
1312
ident_root
1313
)
1314
} else if let Some((_, name, dtype)) = schema.get_full(name) {
1315
resolve_column(ctx, ident_root, name, dtype)
1316
} else {
1317
polars_bail!(
1318
SQLInterface: "no column named '{}' found in table '{}'",
1319
name,
1320
ident_root
1321
)
1322
}
1323
};
1324
1325
// additional ident levels index into struct fields
1326
let (mut column, mut dtype) = col_dtype?;
1327
for ident in remaining_idents {
1328
let name = ident.value.as_str();
1329
match dtype {
1330
Some(DataType::Struct(fields)) if name == "*" => {
1331
return Ok(fields
1332
.iter()
1333
.map(|fld| column.clone().struct_().field_by_name(&fld.name))
1334
.collect());
1335
},
1336
Some(DataType::Struct(fields)) => {
1337
dtype = fields
1338
.iter()
1339
.find(|fld| fld.name == name)
1340
.map(|fld| &fld.dtype);
1341
},
1342
Some(dtype) if name == "*" => {
1343
polars_bail!(SQLSyntax: "cannot expand '*' on non-Struct dtype; found {:?}", dtype)
1344
},
1345
_ => {
1346
dtype = None;
1347
},
1348
}
1349
column = column.struct_().field_by_name(name);
1350
}
1351
Ok(vec![column])
1352
}
1353
1354