Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/dsl/plan.rs
6939 views
1
use std::fmt;
2
use std::io::{Read, Write};
3
use std::sync::{Arc, Mutex};
4
5
use polars_utils::arena::Node;
6
#[cfg(feature = "serde")]
7
use polars_utils::pl_serialize;
8
use polars_utils::unique_id::UniqueId;
9
use recursive::recursive;
10
#[cfg(feature = "serde")]
11
use serde::{Deserialize, Serialize};
12
13
use super::*;
14
15
// DSL format version in a form of (Major, Minor).
16
//
17
// It is no longer needed to increment this. We use the schema hashes to check for compatibility.
18
//
19
// Only increment if you need to make a breaking change that doesn't change the schema hashes.
20
pub const DSL_VERSION: (u16, u16) = (23, 0);
21
const DSL_MAGIC_BYTES: &[u8] = b"DSL_VERSION";
22
23
const DSL_SCHEMA_HASH: SchemaHash<'static> = SchemaHash::from_hash_file();
24
25
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
// NOTE: variant and field order is significant for serialization and for the
// generated DSL schema hash — do not reorder. New comments below deliberately
// use `//` (not `///`) so the schemars-generated schema is unaffected.
pub enum DslPlan {
    // Scan whose behavior is driven by options supplied from the Python side.
    #[cfg(feature = "python")]
    PythonScan {
        options: crate::dsl::python_dsl::PythonOptionsDsl,
    },
    /// Filter on a boolean mask
    Filter {
        input: Arc<DslPlan>,
        predicate: Expr,
    },
    /// Cache the input at this point in the LP
    Cache {
        input: Arc<DslPlan>,
        id: UniqueId,
    },
    // Scan one or more sources (local or cloud paths).
    Scan {
        sources: ScanSources,
        unified_scan_args: Box<UnifiedScanArgs>,
        scan_type: Box<FileScanDsl>,
        /// Local use cases often repeatedly collect the same `LazyFrame` (e.g. in interactive notebook use-cases),
        /// so we cache the IR conversion here, as the path expansion can be quite slow (especially for cloud paths).
        /// We don't have the arena, as this is always a source node.
        #[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(skip))]
        cached_ir: Arc<Mutex<Option<IR>>>,
    },
    // we keep track of the projection and selection as it is cheaper to first project and then filter
    /// In memory DataFrame
    DataFrameScan {
        df: Arc<DataFrame>,
        schema: SchemaRef,
    },
    /// Polars' `select` operation, this can mean projection, but also full data access.
    Select {
        expr: Vec<Expr>,
        input: Arc<DslPlan>,
        options: ProjectionOptions,
    },
    /// Groupby aggregation
    GroupBy {
        input: Arc<DslPlan>,
        keys: Vec<Expr>,
        aggs: Vec<Expr>,
        maintain_order: bool,
        options: Arc<GroupbyOptions>,
        // Optional user callback applied per group, paired with its output schema.
        apply: Option<(PlanCallback<DataFrame, DataFrame>, SchemaRef)>,
    },
    /// Join operation
    Join {
        input_left: Arc<DslPlan>,
        input_right: Arc<DslPlan>,
        // Invariant: left_on and right_on are equal length.
        left_on: Vec<Expr>,
        right_on: Vec<Expr>,
        // Invariant: Either left_on/right_on or predicates is set (non-empty).
        predicates: Vec<Expr>,
        options: Arc<JoinOptions>,
    },
    /// Adding columns to the table without a Join
    HStack {
        input: Arc<DslPlan>,
        exprs: Vec<Expr>,
        options: ProjectionOptions,
    },
    /// Match / Evolve into a schema
    MatchToSchema {
        input: Arc<DslPlan>,
        /// The schema to match to.
        ///
        /// This is also always the output schema.
        match_schema: SchemaRef,
        per_column: Arc<[MatchToSchemaPerColumn]>,
        extra_columns: ExtraColumnsPolicy,
    },
    // Pass the plan together with its schema through a user callback that
    // returns a replacement plan.
    PipeWithSchema {
        input: Arc<DslPlan>,
        callback: PlanCallback<(DslPlan, Schema), DslPlan>,
    },
    /// Remove duplicates from the table
    Distinct {
        input: Arc<DslPlan>,
        options: DistinctOptionsDSL,
    },
    /// Sort the table
    Sort {
        input: Arc<DslPlan>,
        by_column: Vec<Expr>,
        // Optional (offset, length) slice applied as part of the sort.
        slice: Option<(i64, usize)>,
        sort_options: SortMultipleOptions,
    },
    /// Slice the table
    Slice {
        input: Arc<DslPlan>,
        offset: i64,
        len: IdxSize,
    },
    /// A (User Defined) Function
    MapFunction {
        input: Arc<DslPlan>,
        function: DslFunction,
    },
    /// Vertical concatenation
    Union {
        inputs: Vec<DslPlan>,
        args: UnionArgs,
    },
    /// Horizontal concatenation of multiple plans
    HConcat {
        inputs: Vec<DslPlan>,
        options: HConcatOptions,
    },
    /// This allows expressions to access other tables
    ExtContext {
        input: Arc<DslPlan>,
        contexts: Vec<DslPlan>,
    },
    // Sink the input into the destination described by `payload`.
    Sink {
        input: Arc<DslPlan>,
        payload: SinkType,
    },
    // Sink multiple input plans.
    SinkMultiple {
        inputs: Vec<DslPlan>,
    },
    // Merge two inputs that are sorted on `key`.
    #[cfg(feature = "merge_sorted")]
    MergeSorted {
        input_left: Arc<DslPlan>,
        input_right: Arc<DslPlan>,
        key: PlSmallStr,
    },
    // A plan that has already been lowered to IR.
    IR {
        // Keep the original Dsl around as we need that for serialization.
        dsl: Arc<DslPlan>,
        version: u32,
        // Not serialized: the arena node only makes sense in-process.
        #[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(skip))]
        node: Option<Node>,
    },
}
166
impl Clone for DslPlan {
    // Autogenerated by rust-analyzer, don't care about it looking nice, it just
    // calls clone on every member of every enum variant.
    //
    // NOTE(review): `#[recursive]` (from the `recursive` crate) presumably grows
    // the stack on deep recursion so cloning deeply nested plans does not
    // overflow — confirm against that crate's docs. `clone_on_copy` is allowed
    // because the generated code calls `.clone()` even on `Copy` fields.
    #[rustfmt::skip]
    #[allow(clippy::clone_on_copy)]
    #[recursive]
    fn clone(&self) -> Self {
        match self {
            #[cfg(feature = "python")]
            Self::PythonScan { options } => Self::PythonScan { options: options.clone() },
            Self::Filter { input, predicate } => Self::Filter { input: input.clone(), predicate: predicate.clone() },
            Self::Cache { input, id } => Self::Cache { input: input.clone(), id: *id },
            Self::Scan { sources, unified_scan_args, scan_type, cached_ir } => Self::Scan { sources: sources.clone(), unified_scan_args: unified_scan_args.clone(), scan_type: scan_type.clone(), cached_ir: cached_ir.clone() },
            Self::DataFrameScan { df, schema, } => Self::DataFrameScan { df: df.clone(), schema: schema.clone(), },
            Self::Select { expr, input, options } => Self::Select { expr: expr.clone(), input: input.clone(), options: options.clone() },
            Self::GroupBy { input, keys, aggs, apply, maintain_order, options } => Self::GroupBy { input: input.clone(), keys: keys.clone(), aggs: aggs.clone(), apply: apply.clone(), maintain_order: maintain_order.clone(), options: options.clone() },
            Self::Join { input_left, input_right, left_on, right_on, predicates, options } => Self::Join { input_left: input_left.clone(), input_right: input_right.clone(), left_on: left_on.clone(), right_on: right_on.clone(), options: options.clone(), predicates: predicates.clone() },
            Self::HStack { input, exprs, options } => Self::HStack { input: input.clone(), exprs: exprs.clone(), options: options.clone() },
            Self::MatchToSchema { input, match_schema, per_column, extra_columns } => Self::MatchToSchema { input: input.clone(), match_schema: match_schema.clone(), per_column: per_column.clone(), extra_columns: *extra_columns },
            Self::PipeWithSchema { input, callback } => Self::PipeWithSchema { input: input.clone(), callback: callback.clone() },
            Self::Distinct { input, options } => Self::Distinct { input: input.clone(), options: options.clone() },
            Self::Sort {input,by_column, slice, sort_options } => Self::Sort { input: input.clone(), by_column: by_column.clone(), slice: slice.clone(), sort_options: sort_options.clone() },
            Self::Slice { input, offset, len } => Self::Slice { input: input.clone(), offset: offset.clone(), len: len.clone() },
            Self::MapFunction { input, function } => Self::MapFunction { input: input.clone(), function: function.clone() },
            Self::Union { inputs, args} => Self::Union { inputs: inputs.clone(), args: args.clone() },
            Self::HConcat { inputs, options } => Self::HConcat { inputs: inputs.clone(), options: options.clone() },
            Self::ExtContext { input, contexts, } => Self::ExtContext { input: input.clone(), contexts: contexts.clone() },
            Self::Sink { input, payload } => Self::Sink { input: input.clone(), payload: payload.clone() },
            Self::SinkMultiple { inputs } => Self::SinkMultiple { inputs: inputs.clone() },
            #[cfg(feature = "merge_sorted")]
            Self::MergeSorted { input_left, input_right, key } => Self::MergeSorted { input_left: input_left.clone(), input_right: input_right.clone(), key: key.clone() },
            Self::IR {node, dsl, version} => Self::IR {node: *node, dsl: dsl.clone(), version: *version},
        }
    }
}
202
impl Default for DslPlan {
203
fn default() -> Self {
204
let df = DataFrame::empty();
205
let schema = df.schema().clone();
206
DslPlan::DataFrameScan {
207
df: Arc::new(df),
208
schema,
209
}
210
}
211
}
212
213
/// Extra state threaded into plan serialization.
#[derive(Default, Clone, Copy)]
pub struct PlanSerializationContext {
    /// NOTE(review): forwarded to `pl_serialize::USE_CLOUDPICKLE` in
    /// `serialize_versioned`; presumably selects cloudpickle for embedded
    /// Python objects — confirm against `polars_utils::pl_serialize`.
    pub use_cloudpickle: bool,
}
218
impl DslPlan {
    /// Convert this plan to IR and return its textual description.
    ///
    /// Clones `self` because IR conversion (`to_alp`) consumes the plan.
    pub fn describe(&self) -> PolarsResult<String> {
        Ok(self.clone().to_alp()?.describe())
    }

    /// Like [`Self::describe`], but renders the IR in tree format.
    pub fn describe_tree_format(&self) -> PolarsResult<String> {
        Ok(self.clone().to_alp()?.describe_tree_format())
    }

    /// Return a displayable wrapper around the IR-converted plan.
    pub fn display(&self) -> PolarsResult<impl fmt::Display> {
        // Local newtype so we can return `impl Display` without exposing `IRPlan`.
        struct DslPlanDisplay(IRPlan);
        impl fmt::Display for DslPlanDisplay {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                fmt::Display::fmt(&self.0.as_ref().display(), f)
            }
        }
        Ok(DslPlanDisplay(self.clone().to_alp()?))
    }

    /// Lower this DSL plan into its IR representation.
    ///
    /// Consumes `self`; uses default `OptFlags` and fresh arenas.
    pub fn to_alp(self) -> PolarsResult<IRPlan> {
        let mut lp_arena = Arena::with_capacity(16);
        let mut expr_arena = Arena::with_capacity(16);

        let node = to_alp(
            self,
            &mut expr_arena,
            &mut lp_arena,
            &mut OptFlags::default(),
        )?;
        let plan = IRPlan::new(node, lp_arena, expr_arena);

        Ok(plan)
    }

    /// Serialize the plan prefixed with a versioned header.
    ///
    /// Wire layout: `DSL_MAGIC_BYTES`, then `DSL_VERSION` major and minor as
    /// little-endian `u16`s, then the hex `DSL_SCHEMA_HASH`, then the
    /// serde-encoded plan. Must stay in sync with `deserialize_versioned`.
    #[cfg(feature = "serde")]
    pub fn serialize_versioned<W: Write>(
        &self,
        mut writer: W,
        ctx: PlanSerializationContext,
    ) -> PolarsResult<()> {
        let le_major = DSL_VERSION.0.to_le_bytes();
        let le_minor = DSL_VERSION.1.to_le_bytes();

        // @GB:
        // This is absolute horrendous but serde does not allow for state to passed along with the
        // serialization so there is no proper way to do this except replace serde.
        polars_utils::pl_serialize::USE_CLOUDPICKLE.set(ctx.use_cloudpickle);

        writer.write_all(DSL_MAGIC_BYTES)?;
        writer.write_all(&le_major)?;
        writer.write_all(&le_minor)?;
        writer.write_all(DSL_SCHEMA_HASH.as_bytes())?;
        pl_serialize::serialize_dsl(writer, self)
    }

    /// Deserialize a plan written by [`Self::serialize_versioned`], validating
    /// the magic bytes, DSL version, and schema hash before decoding the body.
    ///
    /// # Errors
    /// Fails on a short read, missing magic bytes, a different major version,
    /// a *higher* minor version (lower/equal minors are accepted), an invalid
    /// or mismatching schema hash, or a decode error in the payload.
    #[cfg(feature = "serde")]
    pub fn deserialize_versioned<R: Read>(mut reader: R) -> PolarsResult<Self> {
        const MAGIC_LEN: usize = DSL_MAGIC_BYTES.len();
        // Read magic bytes plus the two u16 version fields in one go.
        let mut version_magic = [0u8; MAGIC_LEN + 4];
        reader
            .read_exact(&mut version_magic)
            .map_err(|e| polars_err!(ComputeError: "failed to read incoming DSL_VERSION: {e}"))?;

        if &version_magic[..MAGIC_LEN] != DSL_MAGIC_BYTES {
            polars_bail!(ComputeError: "dsl magic bytes not found")
        }

        let major = u16::from_le_bytes(version_magic[MAGIC_LEN..MAGIC_LEN + 2].try_into().unwrap());
        let minor = u16::from_le_bytes(
            version_magic[MAGIC_LEN + 2..MAGIC_LEN + 4]
                .try_into()
                .unwrap(),
        );

        const MAJOR: u16 = DSL_VERSION.0;
        const MINOR: u16 = DSL_VERSION.1;

        if polars_core::config::verbose() {
            eprintln!(
                "incoming DSL_VERSION: {major}.{minor}, deserializer DSL_VERSION: {MAJOR}.{MINOR}"
            );
        }

        // Major versions must match exactly.
        if major != MAJOR {
            polars_bail!(ComputeError:
                "deserialization failed\n\ngiven DSL_VERSION: {major}.{minor} is not compatible with this Polars version which uses DSL_VERSION: {MAJOR}.{MINOR}\n{}",
                "error: can't deserialize DSL with a different major version"
            );
        }

        // Payloads from a newer minor version within the same major are rejected.
        if minor > MINOR {
            polars_bail!(ComputeError:
                "deserialization failed\n\ngiven DSL_VERSION: {major}.{minor} is not compatible with this Polars version which uses DSL_VERSION: {MAJOR}.{MINOR}\n{}",
                "error: can't deserialize DSL with a higher minor version"
            );
        }

        let mut schema_hash = [0_u8; SCHEMA_HASH_LEN];
        reader.read_exact(&mut schema_hash).map_err(
            |e| polars_err!(ComputeError: "failed to read incoming DSL_SCHEMA_HASH: {e}"),
        )?;
        let incoming_hash = SchemaHash::new(&schema_hash).ok_or_else(
            || polars_err!(ComputeError: "failed to read incoming DSL schema hash, not a valid hex string")
        )?;

        if polars_core::config::verbose() {
            eprintln!(
                "incoming DSL_SCHEMA_HASH: {incoming_hash}, deserializer DSL_SCHEMA_HASH: {DSL_SCHEMA_HASH}"
            );
        }

        // Hash comparison is case-insensitive (see `SchemaHash`'s `PartialEq`).
        if incoming_hash != DSL_SCHEMA_HASH {
            polars_bail!(ComputeError:
                "deserialization failed\n\ngiven DSL_SCHEMA_HASH: {incoming_hash} is not compatible with this Polars version which uses DSL_SCHEMA_HASH: {DSL_SCHEMA_HASH}\n{}",
                "error: can't deserialize DSL with incompatible schema"
            );
        }

        pl_serialize::deserialize_dsl(reader)
            .map_err(|e| polars_err!(ComputeError: "deserialization failed\n\nerror: {e}"))
    }

    /// Generate the JSON schema describing `DslPlan`, with doc-comment
    /// descriptions stripped and the schema hash added as a top-level
    /// `"hash"` extension field.
    #[cfg(feature = "dsl-schema")]
    pub fn dsl_schema() -> schemars::schema::RootSchema {
        use schemars::r#gen::SchemaSettings;
        use schemars::schema::SchemaObject;
        use schemars::visit::{Visitor, visit_schema_object};

        #[derive(Clone, Copy, Debug)]
        struct MyVisitor;

        impl Visitor for MyVisitor {
            fn visit_schema_object(&mut self, schema: &mut SchemaObject) {
                // Remove descriptions auto-generated from doc comments
                if schema.metadata.is_some() {
                    schema.metadata().description = None;
                }

                // Recurse into nested schema objects.
                visit_schema_object(self, schema);
            }
        }

        let mut schema = SchemaSettings::default()
            .with_visitor(MyVisitor)
            .into_generator()
            .into_root_schema_for::<DslPlan>();

        // Add the DSL schema hash as a top level field
        schema
            .schema
            .extensions
            .insert("hash".into(), DSL_SCHEMA_HASH.to_string().into());

        schema
    }
}
375
/// Byte length of the hex-encoded schema hash (64 hex chars, i.e. a SHA-256
/// digest — the hash file is `dsl-schema.sha256`).
const SCHEMA_HASH_LEN: usize = 64;

/// A validated hex-encoded DSL schema hash; compared case-insensitively.
struct SchemaHash<'a>(&'a str);
379
impl SchemaHash<'static> {
    /// The schema hash baked into the binary at compile time.
    ///
    /// Panics at compile time if the generated hash file is not valid hex.
    const fn from_hash_file() -> Self {
        // Generated by build.rs
        let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/dsl-schema.sha256"));
        Self::new(bytes).expect("not a valid hex string")
    }
}
387
impl<'a> SchemaHash<'a> {
388
const fn new(bytes: &'a [u8; SCHEMA_HASH_LEN]) -> Option<Self> {
389
let mut i = 0;
390
while i < bytes.len() {
391
if !bytes[i].is_ascii_hexdigit() {
392
return None;
393
};
394
i += 1;
395
}
396
match str::from_utf8(bytes) {
397
Ok(hash) => Some(Self(hash)),
398
Err(_) => unreachable!(),
399
}
400
}
401
402
fn as_bytes(&self) -> &'a [u8; SCHEMA_HASH_LEN] {
403
self.0.as_bytes().try_into().unwrap()
404
}
405
}
406
407
impl PartialEq for SchemaHash<'_> {
408
fn eq(&self, other: &Self) -> bool {
409
self.0.eq_ignore_ascii_case(other.0)
410
}
411
}
412
413
impl std::fmt::Display for SchemaHash<'_> {
414
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
415
write!(f, "{}", self.0)
416
}
417
}
418
419