Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-sql/src/sql_visitors.rs
8353 views
1
//! SQLVisitor helper implementations for traversing SQL AST expressions.
2
//!
3
//! This module provides visitor implementations used throughout the SQL interface
4
//! to analyze and check SQL expressions for various properties.
5
6
use std::ops::ControlFlow;
7
8
use polars_core::prelude::*;
9
use sqlparser::ast::{Expr as SQLExpr, ObjectName, Query, SetExpr, Visit, Visitor as SQLVisitor};
10
use sqlparser::keywords::ALL_KEYWORDS;
11
12
// ---------------------------------------------------------------------------
13
// FindTableIdentifier
14
// ---------------------------------------------------------------------------
15
16
/// Expression visitor that detects whether any compound identifier (eg: `tbl.col`)
/// is qualified by a given table name.
pub(crate) struct FindTableIdentifier<'a> {
    /// Set to true once a matching qualified identifier has been seen.
    found: bool,
    /// Name compared against the first part of each compound identifier.
    table_name: &'a str,
}
21
22
impl<'a> FindTableIdentifier<'a> {
23
fn new(table_name: &'a str) -> Self {
24
Self {
25
table_name,
26
found: false,
27
}
28
}
29
}
30
31
impl<'a> SQLVisitor for FindTableIdentifier<'a> {
32
type Break = ();
33
34
fn pre_visit_expr(&mut self, expr: &SQLExpr) -> ControlFlow<Self::Break> {
35
if let SQLExpr::CompoundIdentifier(idents) = expr {
36
if idents.len() >= 2 && idents[0].value.as_str() == self.table_name {
37
self.found = true; // return immediately on first match
38
return ControlFlow::Break(());
39
}
40
}
41
ControlFlow::Continue(())
42
}
43
}
44
45
/// Check if a SQL expression contains a reference to a specific table.
46
pub(crate) fn expr_refers_to_table(expr: &SQLExpr, table_name: &str) -> bool {
47
let mut table_finder = FindTableIdentifier::new(table_name);
48
let _ = expr.visit(&mut table_finder);
49
table_finder.found
50
}
51
52
// ---------------------------------------------------------------------------
53
// QualifyExpression
54
// ---------------------------------------------------------------------------
55
56
/// Visitor used to check a SQL expression used in a QUALIFY clause.
57
/// (Confirms window functions are present and collects column refs in one pass).
58
pub(crate) struct QualifyExpression {
59
has_window_functions: bool,
60
column_refs: PlHashSet<String>,
61
}
62
63
impl QualifyExpression {
64
fn new() -> Self {
65
Self {
66
has_window_functions: false,
67
column_refs: PlHashSet::new(),
68
}
69
}
70
71
pub(crate) fn analyze(expr: &SQLExpr) -> (bool, PlHashSet<String>) {
72
let mut analyzer = Self::new();
73
let _ = expr.visit(&mut analyzer);
74
(analyzer.has_window_functions, analyzer.column_refs)
75
}
76
}
77
78
impl SQLVisitor for QualifyExpression {
79
type Break = ();
80
81
fn pre_visit_expr(&mut self, expr: &SQLExpr) -> ControlFlow<Self::Break> {
82
match expr {
83
SQLExpr::Function(func) if func.over.is_some() => {
84
self.has_window_functions = true;
85
},
86
SQLExpr::Identifier(ident) => {
87
self.column_refs.insert(ident.value.clone());
88
},
89
SQLExpr::CompoundIdentifier(idents) if !idents.is_empty() => {
90
self.column_refs
91
.insert(idents.last().unwrap().value.clone());
92
},
93
_ => {},
94
}
95
ControlFlow::Continue(())
96
}
97
}
98
99
// ---------------------------------------------------------------------------
100
// AmbiguousColumnVisitor
101
// ---------------------------------------------------------------------------
102
103
/// Format an identifier, quoting only if necessary (or `force` is true).
104
fn maybe_quote(s: &str, force: bool) -> String {
105
let needs_quoting = force
106
|| s.is_empty()
107
|| s.starts_with(|c: char| c.is_ascii_digit())
108
|| !s.chars().all(|c| c.is_ascii_alphanumeric() || c == '_')
109
|| ALL_KEYWORDS.contains(&s.to_ascii_uppercase().as_str());
110
if needs_quoting {
111
format!("\"{s}\"")
112
} else {
113
s.to_string()
114
}
115
}
116
117
/// Visitor that checks for unqualified references to columns that exist in
118
/// multiple tables (columns appearing in a USING clause are excluded from
119
/// the check as they are implicitly coalesced).
120
struct AmbiguousColumnVisitor<'a> {
121
joined_aliases: &'a PlHashMap<String, PlHashMap<String, String>>,
122
base_table_name: &'a str,
123
using_cols: &'a PlHashSet<String>,
124
}
125
126
impl SQLVisitor for AmbiguousColumnVisitor<'_> {
127
type Break = PolarsError;
128
129
fn pre_visit_expr(&mut self, expr: &SQLExpr) -> ControlFlow<Self::Break> {
130
if let SQLExpr::Identifier(ident) = expr {
131
let col = &ident.value;
132
if self.using_cols.contains(col) {
133
return ControlFlow::Continue(());
134
}
135
let mut tables: Vec<_> = self
136
.joined_aliases
137
.iter()
138
.filter_map(|(t, cols)| cols.contains_key(col).then_some(t.as_str()))
139
.collect();
140
141
if !tables.is_empty() {
142
tables.push(self.base_table_name);
143
tables.sort();
144
let col_hint = maybe_quote(col, false);
145
let hints = tables
146
.iter()
147
.map(|t| format!("{}.{}", maybe_quote(t, false), col_hint));
148
return ControlFlow::Break(polars_err!(
149
SQLInterface: "ambiguous reference to column {} (use one of: {})",
150
maybe_quote(col, true), hints.collect::<Vec<_>>().join(", ")
151
));
152
}
153
}
154
ControlFlow::Continue(())
155
}
156
}
157
158
/// Check a SQL expression for unqualified references to columns that
159
/// exist in multiple tables (columns appearing in a USING clause are
160
/// excluded from the check as they are implicitly coalesced).
161
pub(crate) fn check_for_ambiguous_column_refs(
162
expr: &SQLExpr,
163
joined_aliases: &PlHashMap<String, PlHashMap<String, String>>,
164
base_table_name: &str,
165
using_cols: &PlHashSet<String>,
166
) -> PolarsResult<()> {
167
match expr.visit(&mut AmbiguousColumnVisitor {
168
joined_aliases,
169
base_table_name,
170
using_cols,
171
}) {
172
ControlFlow::Break(err) => Err(err),
173
ControlFlow::Continue(()) => Ok(()),
174
}
175
}
176
177
// ---------------------------------------------------------------------------
178
// TableIdentifierCollector
179
// ---------------------------------------------------------------------------
180
181
/// Visitor that gathers the names of all tables referenced by a SQL query, both
/// from relations (eg: in a FROM clause) and from `TABLE <name>` set expressions.
#[derive(Default)]
pub(crate) struct TableIdentifierCollector {
    /// Table names in the order they were encountered; duplicates are kept.
    pub(crate) tables: Vec<String>,
    /// When true, record schema-qualified names (eg: `schema.table`); otherwise
    /// record only the table name itself.
    pub(crate) include_schema: bool,
}
187
188
impl TableIdentifierCollector {
189
pub(crate) fn collect_from_set_expr(&mut self, set_expr: &SetExpr) {
190
// Recursively collect table identifiers from SetExpr nodes
191
match set_expr {
192
SetExpr::Table(tbl) => {
193
self.tables.extend(if self.include_schema {
194
match (&tbl.schema_name, &tbl.table_name) {
195
(Some(schema), Some(table)) => Some(format!("{schema}.{table}")),
196
(None, Some(table)) => Some(table.clone()),
197
_ => None,
198
}
199
} else {
200
tbl.table_name.clone()
201
});
202
},
203
SetExpr::SetOperation { left, right, .. } => {
204
self.collect_from_set_expr(left);
205
self.collect_from_set_expr(right);
206
},
207
SetExpr::Query(query) => self.collect_from_set_expr(&query.body),
208
_ => {},
209
}
210
}
211
}
212
213
impl SQLVisitor for TableIdentifierCollector {
214
type Break = ();
215
216
fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
217
// Collect from SetExpr nodes in the query body
218
self.collect_from_set_expr(&query.body);
219
ControlFlow::Continue(())
220
}
221
222
fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
223
// Table relation (eg: appearing in FROM clause)
224
self.tables.extend(if self.include_schema {
225
let parts: Vec<_> = relation
226
.0
227
.iter()
228
.filter_map(|p| p.as_ident().map(|i| i.value.as_str()))
229
.collect();
230
(!parts.is_empty()).then(|| parts.join("."))
231
} else {
232
relation
233
.0
234
.last()
235
.and_then(|p| p.as_ident())
236
.map(|i| i.value.clone())
237
});
238
ControlFlow::Continue(())
239
}
240
}
241
242
// ---------------------------------------------------------------------------
243
// WindowFunctionFinder
244
// ---------------------------------------------------------------------------
245
246
/// Stateless visitor that halts the traversal at the first function call carrying
/// an `OVER` clause (an explicit window function). Only the yes/no answer is
/// needed, so no state is kept and the walk ends as early as possible.
struct WindowFunctionFinder;
249
250
impl SQLVisitor for WindowFunctionFinder {
251
type Break = ();
252
253
fn pre_visit_expr(&mut self, expr: &SQLExpr) -> ControlFlow<()> {
254
if matches!(expr, SQLExpr::Function(f) if f.over.is_some()) {
255
ControlFlow::Break(())
256
} else {
257
ControlFlow::Continue(())
258
}
259
}
260
}
261
262
/// Check if a SQL expression contains explicit window functions.
263
pub(crate) fn expr_has_window_functions(expr: &SQLExpr) -> bool {
264
expr.visit(&mut WindowFunctionFinder).is_break()
265
}
266
267