use std::fmt;
use std::io::{Read, Write};
use std::sync::{Arc, Mutex};

use polars_utils::arena::Node;
#[cfg(feature = "serde")]
use polars_utils::pl_serialize;
use polars_utils::unique_id::UniqueId;
use recursive::recursive;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use super::*;

// DSL format version in the form of (Major, Minor).
//
// It is no longer needed to increment this. We use the schema hashes to check for compatibility.
//
// Only increment if you need to make a breaking change that doesn't change the schema hashes.
pub const DSL_VERSION: (u16, u16) = (24, 0);
const DSL_MAGIC_BYTES: &[u8] = b"DSL_VERSION";

const DSL_SCHEMA_HASH: SchemaHash<'static> = SchemaHash::from_hash_file();
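
// The versioned header written by `serialize_versioned` below is, in order:
// the magic bytes b"DSL_VERSION", the major then the minor version as a
// little-endian `u16`, the 64-byte ASCII-hex schema hash, and finally the
// serialized plan payload.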

#[derive(Debug, strum_macros::IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum DslPlan {
    #[cfg(feature = "python")]
    PythonScan {
        options: crate::dsl::python_dsl::PythonOptionsDsl,
    },
    /// Filter on a boolean mask
    Filter {
        input: Arc<DslPlan>,
        predicate: Expr,
    },
    /// Cache the input at this point in the LP
    Cache {
        input: Arc<DslPlan>,
        id: UniqueId,
    },
    Scan {
        sources: ScanSources,
        unified_scan_args: Box<UnifiedScanArgs>,
        scan_type: Box<FileScanDsl>,
        /// Local use cases often repeatedly collect the same `LazyFrame` (e.g. in interactive notebook use-cases),
        /// so we cache the IR conversion here, as the path expansion can be quite slow (especially for cloud paths).
        /// We don't have the arena, as this is always a source node.
        #[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(skip))]
        cached_ir: Arc<Mutex<Option<IR>>>,
    },
    // We keep track of the projection and selection, as it is cheaper to first project and then filter.
    /// In memory DataFrame
    DataFrameScan {
        df: Arc<DataFrame>,
        schema: SchemaRef,
    },
    /// Polars' `select` operation; this can mean projection, but also full data access.
    Select {
        expr: Vec<Expr>,
        input: Arc<DslPlan>,
        options: ProjectionOptions,
    },
    /// Groupby aggregation
    GroupBy {
        input: Arc<DslPlan>,
        keys: Vec<Expr>,
        predicates: Vec<Expr>,
        aggs: Vec<Expr>,
        maintain_order: bool,
        options: Arc<GroupbyOptions>,
        apply: Option<(PlanCallback<DataFrame, DataFrame>, SchemaRef)>,
    },
    /// Join operation
    Join {
        input_left: Arc<DslPlan>,
        input_right: Arc<DslPlan>,
        // Invariant: left_on and right_on are of equal length.
        left_on: Vec<Expr>,
        right_on: Vec<Expr>,
        // Invariant: Either left_on/right_on or predicates is set (non-empty).
        predicates: Vec<Expr>,
        options: Arc<JoinOptions>,
    },
    /// Adding columns to the table without a Join
    HStack {
        input: Arc<DslPlan>,
        exprs: Vec<Expr>,
        options: ProjectionOptions,
    },
    /// Match / Evolve into a schema
    MatchToSchema {
        input: Arc<DslPlan>,
        /// The schema to match to.
        ///
        /// This is also always the output schema.
        match_schema: SchemaRef,

        per_column: Arc<[MatchToSchemaPerColumn]>,

        extra_columns: ExtraColumnsPolicy,
    },
    PipeWithSchema {
        input: Arc<[DslPlan]>,
        callback: PlanCallback<(Vec<DslPlan>, Vec<SchemaRef>), DslPlan>,
    },
    #[cfg(feature = "pivot")]
    Pivot {
        input: Arc<DslPlan>,
        on: Selector,
        on_columns: Arc<DataFrame>,
        index: Selector,
        values: Selector,
        agg: Expr,
        maintain_order: bool,
        separator: PlSmallStr,
    },
    /// Remove duplicates from the table
    Distinct {
        input: Arc<DslPlan>,
        options: DistinctOptionsDSL,
    },
    /// Sort the table
    Sort {
        input: Arc<DslPlan>,
        by_column: Vec<Expr>,
        slice: Option<(i64, usize)>,
        sort_options: SortMultipleOptions,
    },
    /// Slice the table
    Slice {
        input: Arc<DslPlan>,
        offset: i64,
        len: IdxSize,
    },
    /// A (User Defined) Function
    MapFunction {
        input: Arc<DslPlan>,
        function: DslFunction,
    },
    /// Vertical concatenation
    Union {
        inputs: Vec<DslPlan>,
        args: UnionArgs,
    },
    /// Horizontal concatenation of multiple plans
    HConcat {
        inputs: Vec<DslPlan>,
        options: HConcatOptions,
    },
    /// This allows expressions to access other tables
    ExtContext {
        input: Arc<DslPlan>,
        contexts: Vec<DslPlan>,
    },
    Sink {
        input: Arc<DslPlan>,
        payload: SinkType,
    },
    SinkMultiple {
        inputs: Vec<DslPlan>,
    },
    #[cfg(feature = "merge_sorted")]
    MergeSorted {
        input_left: Arc<DslPlan>,
        input_right: Arc<DslPlan>,
        key: PlSmallStr,
    },
    IR {
        // Keep the original DSL around, as we need it for serialization.
        dsl: Arc<DslPlan>,
        version: u32,
        #[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(skip))]
        node: Option<Node>,
    },
}
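
// A minimal construction sketch (illustrative only; assumes `lit` from the
// surrounding dsl module): plans nest through their `Arc<DslPlan>` inputs,
// e.g. a trivial filter over an empty in-memory frame:
//
//     let filtered = DslPlan::Filter {
//         input: Arc::new(DslPlan::default()),
//         predicate: lit(true),
//     };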

impl Clone for DslPlan {
    // Autogenerated by rust-analyzer; we don't care about it looking nice, as it
    // just calls clone on every member of every enum variant.
    #[rustfmt::skip]
    #[allow(clippy::clone_on_copy)]
    #[recursive]
    fn clone(&self) -> Self {
        match self {
            #[cfg(feature = "python")]
            Self::PythonScan { options } => Self::PythonScan { options: options.clone() },
            Self::Filter { input, predicate } => Self::Filter { input: input.clone(), predicate: predicate.clone() },
            Self::Cache { input, id } => Self::Cache { input: input.clone(), id: *id },
            Self::Scan { sources, unified_scan_args, scan_type, cached_ir } => Self::Scan { sources: sources.clone(), unified_scan_args: unified_scan_args.clone(), scan_type: scan_type.clone(), cached_ir: cached_ir.clone() },
            Self::DataFrameScan { df, schema } => Self::DataFrameScan { df: df.clone(), schema: schema.clone() },
            Self::Select { expr, input, options } => Self::Select { expr: expr.clone(), input: input.clone(), options: options.clone() },
            Self::GroupBy { input, keys, predicates, aggs, apply, maintain_order, options } => Self::GroupBy { input: input.clone(), keys: keys.clone(), predicates: predicates.clone(), aggs: aggs.clone(), apply: apply.clone(), maintain_order: maintain_order.clone(), options: options.clone() },
            Self::Join { input_left, input_right, left_on, right_on, predicates, options } => Self::Join { input_left: input_left.clone(), input_right: input_right.clone(), left_on: left_on.clone(), right_on: right_on.clone(), options: options.clone(), predicates: predicates.clone() },
            Self::HStack { input, exprs, options } => Self::HStack { input: input.clone(), exprs: exprs.clone(), options: options.clone() },
            Self::MatchToSchema { input, match_schema, per_column, extra_columns } => Self::MatchToSchema { input: input.clone(), match_schema: match_schema.clone(), per_column: per_column.clone(), extra_columns: *extra_columns },
            Self::PipeWithSchema { input, callback } => Self::PipeWithSchema { input: input.clone(), callback: callback.clone() },
            Self::Distinct { input, options } => Self::Distinct { input: input.clone(), options: options.clone() },
            Self::Sort { input, by_column, slice, sort_options } => Self::Sort { input: input.clone(), by_column: by_column.clone(), slice: slice.clone(), sort_options: sort_options.clone() },
            Self::Slice { input, offset, len } => Self::Slice { input: input.clone(), offset: offset.clone(), len: len.clone() },
            Self::MapFunction { input, function } => Self::MapFunction { input: input.clone(), function: function.clone() },
            Self::Union { inputs, args } => Self::Union { inputs: inputs.clone(), args: args.clone() },
            Self::HConcat { inputs, options } => Self::HConcat { inputs: inputs.clone(), options: options.clone() },
            Self::ExtContext { input, contexts } => Self::ExtContext { input: input.clone(), contexts: contexts.clone() },
            Self::Sink { input, payload } => Self::Sink { input: input.clone(), payload: payload.clone() },
            Self::SinkMultiple { inputs } => Self::SinkMultiple { inputs: inputs.clone() },
            #[cfg(feature = "pivot")]
            Self::Pivot { input, on, on_columns, index, values, agg, separator, maintain_order } => Self::Pivot { input: input.clone(), on: on.clone(), on_columns: on_columns.clone(), index: index.clone(), values: values.clone(), agg: agg.clone(), separator: separator.clone(), maintain_order: *maintain_order },
            #[cfg(feature = "merge_sorted")]
            Self::MergeSorted { input_left, input_right, key } => Self::MergeSorted { input_left: input_left.clone(), input_right: input_right.clone(), key: key.clone() },
            Self::IR { node, dsl, version } => Self::IR { node: *node, dsl: dsl.clone(), version: *version },
        }
    }
}

impl Default for DslPlan {
    fn default() -> Self {
        let df = DataFrame::empty();
        let schema = df.schema().clone();
        DslPlan::DataFrameScan {
            df: Arc::new(df),
            schema,
        }
    }
}

#[derive(Default, Clone, Copy)]
pub struct PlanSerializationContext {
    pub use_cloudpickle: bool,
}

impl DslPlan {
    pub fn describe(&self) -> PolarsResult<String> {
        Ok(self.clone().to_alp()?.describe())
    }

    pub fn describe_tree_format(&self) -> PolarsResult<String> {
        Ok(self.clone().to_alp()?.describe_tree_format())
    }

    pub fn display(&self) -> PolarsResult<impl fmt::Display> {
        struct DslPlanDisplay(IRPlan);
        impl fmt::Display for DslPlanDisplay {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                fmt::Display::fmt(&self.0.as_ref().display(), f)
            }
        }
        Ok(DslPlanDisplay(self.clone().to_alp()?))
    }

    pub fn to_alp(self) -> PolarsResult<IRPlan> {
        let mut lp_arena = Arena::with_capacity(16);
        let mut expr_arena = Arena::with_capacity(16);

        let node = to_alp(
            self,
            &mut expr_arena,
            &mut lp_arena,
            &mut OptFlags::default(),
        )?;
        let plan = IRPlan::new(node, lp_arena, expr_arena);

        Ok(plan)
    }
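
    // A minimal usage sketch (illustrative): lower a plan to IR and render it,
    // e.g.
    //
    //     println!("{}", DslPlan::default().to_alp()?.describe());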

    #[cfg(feature = "serde")]
    pub fn serialize_versioned<W: Write>(
        &self,
        mut writer: W,
        ctx: PlanSerializationContext,
    ) -> PolarsResult<()> {
        let le_major = DSL_VERSION.0.to_le_bytes();
        let le_minor = DSL_VERSION.1.to_le_bytes();

        // @GB:
        // This is absolutely horrendous, but serde does not allow state to be passed
        // along with the serialization, so there is no proper way to do this short
        // of replacing serde.
        polars_utils::pl_serialize::USE_CLOUDPICKLE.set(ctx.use_cloudpickle);

        writer.write_all(DSL_MAGIC_BYTES)?;
        writer.write_all(&le_major)?;
        writer.write_all(&le_minor)?;
        writer.write_all(DSL_SCHEMA_HASH.as_bytes())?;
        let serializable_plan = serializable_plan::SerializableDslPlan::from(self);
        pl_serialize::serialize_dsl(writer, &serializable_plan)
            .map_err(|e| polars_err!(ComputeError: "serialization failed\n\nerror: {e}"))
    }

    #[cfg(feature = "serde")]
    pub fn deserialize_versioned<R: Read>(mut reader: R) -> PolarsResult<Self> {
        const MAGIC_LEN: usize = DSL_MAGIC_BYTES.len();
        let mut version_magic = [0u8; MAGIC_LEN + 4];
        reader
            .read_exact(&mut version_magic)
            .map_err(|e| polars_err!(ComputeError: "failed to read incoming DSL_VERSION: {e}"))?;

        if &version_magic[..MAGIC_LEN] != DSL_MAGIC_BYTES {
            polars_bail!(ComputeError: "dsl magic bytes not found")
        }

        let major = u16::from_le_bytes(version_magic[MAGIC_LEN..MAGIC_LEN + 2].try_into().unwrap());
        let minor = u16::from_le_bytes(
            version_magic[MAGIC_LEN + 2..MAGIC_LEN + 4]
                .try_into()
                .unwrap(),
        );

        const MAJOR: u16 = DSL_VERSION.0;
        const MINOR: u16 = DSL_VERSION.1;

        if polars_core::config::verbose() {
            eprintln!(
                "incoming DSL_VERSION: {major}.{minor}, deserializer DSL_VERSION: {MAJOR}.{MINOR}"
            );
        }

        if major != MAJOR {
            polars_bail!(ComputeError:
                "deserialization failed\n\ngiven DSL_VERSION: {major}.{minor} is not compatible with this Polars version which uses DSL_VERSION: {MAJOR}.{MINOR}\n{}",
                "error: can't deserialize DSL with a different major version"
            );
        }

        if minor > MINOR {
            polars_bail!(ComputeError:
                "deserialization failed\n\ngiven DSL_VERSION: {major}.{minor} is not compatible with this Polars version which uses DSL_VERSION: {MAJOR}.{MINOR}\n{}",
                "error: can't deserialize DSL with a higher minor version"
            );
        }

        let mut schema_hash = [0_u8; SCHEMA_HASH_LEN];
        reader.read_exact(&mut schema_hash).map_err(
            |e| polars_err!(ComputeError: "failed to read incoming DSL_SCHEMA_HASH: {e}"),
        )?;

        let incoming_hash = SchemaHash::new(&schema_hash).ok_or_else(
            || polars_err!(ComputeError: "failed to read incoming DSL schema hash, not a valid hex string")
        )?;

        if polars_core::config::verbose() {
            eprintln!(
                "incoming DSL_SCHEMA_HASH: {incoming_hash}, deserializer DSL_SCHEMA_HASH: {DSL_SCHEMA_HASH}"
            );
        }

        if std::env::var("POLARS_SKIP_DSL_HASH_VERIFICATION").as_deref() != Ok("1")
            && incoming_hash != DSL_SCHEMA_HASH
        {
            polars_bail!(ComputeError:
                "deserialization failed\n\ngiven DSL_SCHEMA_HASH: {incoming_hash} is not compatible with this Polars version which uses DSL_SCHEMA_HASH: {DSL_SCHEMA_HASH}\n{}",
                "error: can't deserialize DSL with incompatible schema"
            );
        }

        let serializable_plan: serializable_plan::SerializableDslPlan =
            pl_serialize::deserialize_dsl(reader)
                .map_err(|e| polars_err!(ComputeError: "deserialization failed\n\nerror: {e}"))?;
        (&serializable_plan).try_into()
    }

    #[cfg(feature = "dsl-schema")]
    pub fn dsl_schema() -> schemars::Schema {
        use schemars::Schema;
        use schemars::generate::SchemaSettings;
        use schemars::transform::{Transform, transform_subschemas};

        #[derive(Clone, Copy, Debug)]
        struct MyTransform;

        impl Transform for MyTransform {
            fn transform(&mut self, schema: &mut Schema) {
                // Remove descriptions auto-generated from doc comments
                schema.remove("description");

                transform_subschemas(self, schema);
            }
        }

        let mut schema = SchemaSettings::default()
            .with_transform(MyTransform)
            .into_generator()
            .into_root_schema_for::<DslPlan>();

        // Add the DSL schema hash as a top level field
        schema.insert("hash".into(), DSL_SCHEMA_HASH.to_string().into());

        schema
    }
}
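
// A minimal usage sketch (illustrative; assumes the `dsl-schema` feature and
// that the caller has `serde_json` available):
//
//     let schema = DslPlan::dsl_schema();
//     println!("{}", serde_json::to_string_pretty(&schema).unwrap());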

const SCHEMA_HASH_LEN: usize = 64;

struct SchemaHash<'a>(&'a str);

impl SchemaHash<'static> {
    const fn from_hash_file() -> Self {
        // Generated by build.rs
        let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/dsl-schema.sha256"));
        Self::new(bytes).expect("not a valid hex string")
    }
}

impl<'a> SchemaHash<'a> {
    const fn new(bytes: &'a [u8; SCHEMA_HASH_LEN]) -> Option<Self> {
        // Validate that every byte is an ASCII hex digit before reinterpreting
        // the buffer as a string.
        let mut i = 0;
        while i < bytes.len() {
            if !bytes[i].is_ascii_hexdigit() {
                return None;
            }
            i += 1;
        }
        match str::from_utf8(bytes) {
            Ok(hash) => Some(Self(hash)),
            // Unreachable: ASCII hex digits are always valid UTF-8.
            Err(_) => unreachable!(),
        }
    }

    fn as_bytes(&self) -> &'a [u8; SCHEMA_HASH_LEN] {
        self.0.as_bytes().try_into().unwrap()
    }
}

impl PartialEq for SchemaHash<'_> {
    fn eq(&self, other: &Self) -> bool {
        self.0.eq_ignore_ascii_case(other.0)
    }
}

impl std::fmt::Display for SchemaHash<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}
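
// Illustrative test sketches, written against only the APIs declared above
// (`SchemaHash::new`, `serialize_versioned`, `deserialize_versioned`); they
// are not part of the original file.
#[cfg(test)]
mod tests {
    use super::*;

    // `SchemaHash::new` accepts exactly SCHEMA_HASH_LEN ASCII hex digits, and
    // equality is case-insensitive.
    #[test]
    fn schema_hash_validation() {
        let lower = [b'a'; SCHEMA_HASH_LEN];
        let upper = [b'A'; SCHEMA_HASH_LEN];
        let invalid = [b'z'; SCHEMA_HASH_LEN];

        assert!(SchemaHash::new(&invalid).is_none());
        let lower = SchemaHash::new(&lower).unwrap();
        let upper = SchemaHash::new(&upper).unwrap();
        assert!(lower == upper);
    }

    // Roundtrip the default plan (an empty in-memory DataFrame scan) through
    // the versioned serialization format.
    #[cfg(feature = "serde")]
    #[test]
    fn versioned_roundtrip() {
        let plan = DslPlan::default();
        let mut buf = Vec::new();
        plan.serialize_versioned(&mut buf, PlanSerializationContext::default())
            .unwrap();
        let _roundtripped = DslPlan::deserialize_versioned(buf.as_slice()).unwrap();
    }
}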