Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/dsl/options/mod.rs
6940 views
1
use std::hash::Hash;
2
#[cfg(feature = "json")]
3
use std::num::NonZeroUsize;
4
use std::str::FromStr;
5
use std::sync::Arc;
6
7
mod sink;
8
9
use polars_core::error::PolarsResult;
10
use polars_core::prelude::*;
11
#[cfg(feature = "csv")]
12
use polars_io::csv::write::CsvWriterOptions;
13
#[cfg(feature = "ipc")]
14
use polars_io::ipc::IpcWriterOptions;
15
#[cfg(feature = "json")]
16
use polars_io::json::JsonWriterOptions;
17
#[cfg(feature = "parquet")]
18
use polars_io::parquet::write::ParquetWriteOptions;
19
#[cfg(feature = "iejoin")]
20
use polars_ops::frame::IEJoinOptions;
21
use polars_ops::frame::{CrossJoinFilter, CrossJoinOptions, JoinTypeOptions};
22
use polars_ops::prelude::{JoinArgs, JoinType};
23
#[cfg(feature = "dynamic_group_by")]
24
use polars_time::DynamicGroupOptions;
25
#[cfg(feature = "dynamic_group_by")]
26
use polars_time::RollingGroupOptions;
27
use polars_utils::IdxSize;
28
use polars_utils::pl_str::PlSmallStr;
29
#[cfg(feature = "serde")]
30
use serde::{Deserialize, Serialize};
31
pub use sink::*;
32
use strum_macros::IntoStaticStr;
33
34
use super::ExprIR;
35
use crate::dsl::Selector;
36
37
/// Options controlling rolling covariance/correlation computations.
#[derive(Copy, Clone, PartialEq, Debug, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct RollingCovOptions {
    /// Size of the rolling window, in rows.
    pub window_size: IdxSize,
    /// Minimum number of observations in the window required to emit a value.
    pub min_periods: IdxSize,
    /// Delta degrees of freedom.
    pub ddof: u8,
}
45
46
/// Options for parsing strings into temporal types (strptime-style).
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct StrptimeOptions {
    /// Formatting string. `None` means the format should be inferred.
    pub format: Option<PlSmallStr>,
    /// If set then polars will return an error if any date parsing fails.
    pub strict: bool,
    /// If `false`, polars may parse matches that do not contain the whole string,
    /// e.g. "foo-2021-01-01-bar" could match "2021-01-01".
    pub exact: bool,
    /// Use a cache of unique, converted dates to apply the datetime conversion.
    pub cache: bool,
}
60
61
impl Default for StrptimeOptions {
62
fn default() -> Self {
63
StrptimeOptions {
64
format: None,
65
strict: true,
66
exact: true,
67
cache: true,
68
}
69
}
70
}
71
72
/// IR-level join-type specific options (lowered from the DSL).
#[derive(Clone, PartialEq, Eq, IntoStaticStr, Debug)]
#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))]
#[strum(serialize_all = "snake_case")]
pub enum JoinTypeOptionsIR {
    /// Inequality join options.
    #[cfg(feature = "iejoin")]
    IEJoin(IEJoinOptions),
    // Fused cross join and filter (only used in the in-memory engine)
    CrossAndFilter {
        predicate: ExprIR, // Must be elementwise.
    },
}
83
84
impl Hash for JoinTypeOptionsIR {
85
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
86
use JoinTypeOptionsIR::*;
87
match self {
88
#[cfg(feature = "iejoin")]
89
IEJoin(opt) => opt.hash(state),
90
CrossAndFilter { predicate } => predicate.node().hash(state),
91
}
92
}
93
}
94
95
impl JoinTypeOptionsIR {
96
pub fn compile<C: FnOnce(&ExprIR) -> PolarsResult<Arc<dyn CrossJoinFilter>>>(
97
self,
98
plan: C,
99
) -> PolarsResult<JoinTypeOptions> {
100
use JoinTypeOptionsIR::*;
101
match self {
102
CrossAndFilter { predicate } => {
103
let predicate = plan(&predicate)?;
104
105
Ok(JoinTypeOptions::Cross(CrossJoinOptions { predicate }))
106
},
107
#[cfg(feature = "iejoin")]
108
IEJoin(opt) => Ok(JoinTypeOptions::IEJoin(opt)),
109
}
110
}
111
}
112
113
/// IR-level join options; extends [`JoinOptions`] with optimizer state.
#[derive(Clone, Debug, PartialEq, Hash)]
#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))]
pub struct JoinOptionsIR {
    pub allow_parallel: bool,
    pub force_parallel: bool,
    pub args: JoinArgs,
    /// Join-type specific options, filled in during optimization.
    pub options: Option<JoinTypeOptionsIR>,
    /// Proxy of the number of rows in both sides of the joins
    /// Holds `(Option<known_size>, estimated_size)`
    pub rows_left: (Option<usize>, usize),
    pub rows_right: (Option<usize>, usize),
}
125
126
impl From<JoinOptions> for JoinOptionsIR {
127
fn from(opts: JoinOptions) -> Self {
128
Self {
129
allow_parallel: opts.allow_parallel,
130
force_parallel: opts.force_parallel,
131
args: opts.args,
132
options: Default::default(),
133
rows_left: (None, usize::MAX),
134
rows_right: (None, usize::MAX),
135
}
136
}
137
}
138
139
/// User-facing (DSL) join options.
#[derive(Clone, Debug, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct JoinOptions {
    /// Allow the optimizer to run both join inputs in parallel.
    pub allow_parallel: bool,
    /// Force parallel execution of both join inputs.
    pub force_parallel: bool,
    /// Join type and key arguments.
    pub args: JoinArgs,
}
147
148
impl Default for JoinOptions {
149
fn default() -> Self {
150
JoinOptions {
151
allow_parallel: true,
152
force_parallel: false,
153
// Todo!: make default
154
args: JoinArgs::new(JoinType::Left),
155
}
156
}
157
}
158
159
impl From<JoinOptionsIR> for JoinOptions {
160
fn from(opts: JoinOptionsIR) -> Self {
161
Self {
162
allow_parallel: opts.allow_parallel,
163
force_parallel: opts.force_parallel,
164
args: opts.args,
165
}
166
}
167
}
168
169
/// Kind of window operation applied by `over`-style expressions.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum WindowType {
    /// Explode the aggregated list and just do a hstack instead of a join
    /// this requires the groups to be sorted to make any sense
    Over(WindowMapping),
    /// Rolling (time/index-window) grouping.
    #[cfg(feature = "dynamic_group_by")]
    Rolling(RollingGroupOptions),
}
179
180
impl From<WindowMapping> for WindowType {
181
fn from(value: WindowMapping) -> Self {
182
Self::Over(value)
183
}
184
}
185
186
impl Default for WindowType {
187
fn default() -> Self {
188
Self::Over(WindowMapping::default())
189
}
190
}
191
192
/// How window-aggregation results are mapped back onto the original rows.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
#[strum(serialize_all = "snake_case")]
pub enum WindowMapping {
    /// Map the group values to the position
    #[default]
    GroupsToRows,
    /// Explode the aggregated list and just do a hstack instead of a join
    /// this requires the groups to be sorted to make any sense
    Explode,
    /// Join the groups as 'List<group_dtype>' to the row positions.
    /// warning: this can be memory intensive
    Join,
}
207
208
/// Nested dtype category selector.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum NestedType {
    #[cfg(feature = "dtype-array")]
    Array,
    // List,
}
215
216
/// Arguments for an `unpivot` (wide-to-long reshape) in the DSL.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct UnpivotArgsDSL {
    /// Columns whose values are unpivoted into rows.
    pub on: Selector,
    /// Columns kept as identifier columns.
    pub index: Selector,
    /// Name for the output variable column; `None` uses the default name.
    pub variable_name: Option<PlSmallStr>,
    /// Name for the output value column; `None` uses the default name.
    pub value_name: Option<PlSmallStr>,
}
225
226
/// Query execution engine selection.
///
/// Parses from the strings `"auto"`, `"in-memory"` (or legacy `"cpu"`),
/// `"streaming"` and `"gpu"`; [`Engine::into_static_str`] yields the
/// canonical name, which round-trips through `FromStr`.
#[derive(Clone, Debug, Copy, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Engine {
    Auto,
    Streaming,
    InMemory,
    Gpu,
}

impl FromStr for Engine {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "auto" => Ok(Self::Auto),
            // "cpu" accepted for backwards compatibility.
            "cpu" | "in-memory" => Ok(Self::InMemory),
            "streaming" => Ok(Self::Streaming),
            "gpu" => Ok(Self::Gpu),
            // Removed engine: give a dedicated error rather than the generic one.
            "old-streaming" => Err("the 'old-streaming' engine has been removed".to_owned()),
            v => Err(format!(
                "`engine` must be one of {{'auto', 'in-memory', 'streaming', 'gpu'}}, got {v}",
            )),
        }
    }
}

impl Engine {
    /// The canonical name of this engine.
    pub fn into_static_str(self) -> &'static str {
        match self {
            Engine::Auto => "auto",
            Engine::Streaming => "streaming",
            Engine::InMemory => "in-memory",
            Engine::Gpu => "gpu",
        }
    }
}
263
264
/// IR-level options for a union/concat node.
#[derive(Clone, Debug, Copy, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct UnionOptions {
    /// Optional `(offset, length)` slice pushed into the union.
    pub slice: Option<(i64, usize)>,
    // known row_output, estimated row output
    pub rows: (Option<usize>, usize),
    pub parallel: bool,
    /// Whether this union stems from a scan over multiple files.
    pub from_partitioned_ds: bool,
    /// Set when an optimization pass flattened nested unions.
    pub flattened_by_opt: bool,
    pub rechunk: bool,
    pub maintain_order: bool,
}
276
277
impl Default for UnionOptions {
278
fn default() -> Self {
279
Self {
280
slice: None,
281
rows: (None, 0),
282
parallel: true,
283
from_partitioned_ds: false,
284
flattened_by_opt: false,
285
rechunk: false,
286
maintain_order: true,
287
}
288
}
289
}
290
291
/// Options for horizontal concatenation.
#[derive(Clone, Debug, Copy, Default, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct HConcatOptions {
    /// Evaluate the inputs in parallel.
    pub parallel: bool,
}
297
298
/// Options attached to a group-by node.
///
/// At most one of `dynamic`/`rolling` is expected to be set — TODO confirm
/// with the construction sites.
#[derive(Clone, Debug, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct GroupbyOptions {
    /// Dynamic (time/index window) group-by options, if any.
    #[cfg(feature = "dynamic_group_by")]
    pub dynamic: Option<DynamicGroupOptions>,
    /// Rolling group-by options, if any.
    #[cfg(feature = "dynamic_group_by")]
    pub rolling: Option<RollingGroupOptions>,
    /// Take only a slice of the result
    pub slice: Option<(i64, usize)>,
}
309
310
impl GroupbyOptions {
311
pub(crate) fn is_rolling(&self) -> bool {
312
#[cfg(feature = "dynamic_group_by")]
313
{
314
self.rolling.is_some()
315
}
316
#[cfg(not(feature = "dynamic_group_by"))]
317
{
318
false
319
}
320
}
321
322
pub(crate) fn is_dynamic(&self) -> bool {
323
#[cfg(feature = "dynamic_group_by")]
324
{
325
self.dynamic.is_some()
326
}
327
#[cfg(not(feature = "dynamic_group_by"))]
328
{
329
false
330
}
331
}
332
}
333
334
/// Options for dropping duplicate rows (`unique`/`distinct`).
#[derive(Clone, Debug, Eq, PartialEq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct DistinctOptionsDSL {
    /// Subset of columns that will be taken into account.
    pub subset: Option<Selector>,
    /// This will maintain the order of the input.
    /// Note that this is more expensive.
    /// `maintain_order` is not supported in the streaming
    /// engine.
    pub maintain_order: bool,
    /// Which rows to keep.
    pub keep_strategy: UniqueKeepStrategy,
}
348
349
/// Optimizer/formatting flags for an opaque UDF node in the logical plan.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct LogicalPlanUdfOptions {
    /// allow predicate pushdown optimizations
    pub predicate_pd: bool,
    /// allow projection pushdown optimizations
    pub projection_pd: bool,
    // used for formatting
    pub fmt_str: &'static str,
}
358
359
/// Options for an anonymous (user-supplied) scan source.
#[derive(Clone, PartialEq, Eq, Debug, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct AnonymousScanOptions {
    /// Number of leading rows to skip, if any.
    pub skip_rows: Option<usize>,
    /// Static label used when formatting/displaying the scan.
    pub fmt_str: &'static str,
}
365
366
/// Output file format together with its writer options.
/// Variants are feature-gated on the corresponding writer.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum FileType {
    #[cfg(feature = "parquet")]
    Parquet(ParquetWriteOptions),
    #[cfg(feature = "ipc")]
    Ipc(IpcWriterOptions),
    #[cfg(feature = "csv")]
    Csv(CsvWriterOptions),
    #[cfg(feature = "json")]
    Json(JsonWriterOptions),
}
379
380
impl FileType {
    /// The file extension conventionally used for this format.
    pub fn extension(&self) -> &'static str {
        match self {
            #[cfg(feature = "parquet")]
            Self::Parquet(_) => "parquet",
            #[cfg(feature = "ipc")]
            Self::Ipc(_) => "ipc",
            #[cfg(feature = "csv")]
            Self::Csv(_) => "csv",
            // "jsonl": JSON-lines output rather than a single JSON document.
            #[cfg(feature = "json")]
            Self::Json(_) => "jsonl",

            // Reachable only when no file-format feature is enabled, in which
            // case the enum has no variants and this match has no real arms.
            #[allow(unreachable_patterns)]
            _ => unreachable!("enable file type features"),
        }
    }
}
397
398
/// Arguments given to `concat`. Differs from `UnionOptions` as the latter is IR state.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct UnionArgs {
    pub parallel: bool,
    pub rechunk: bool,
    /// Cast columns to their common supertype where needed.
    pub to_supertypes: bool,
    /// Diagonal concatenation (union of schemas).
    pub diagonal: bool,
    /// If it is a union from a scan over multiple files.
    pub from_partitioned_ds: bool,
    pub maintain_order: bool,
}
412
413
impl Default for UnionArgs {
414
fn default() -> Self {
415
Self {
416
parallel: true,
417
rechunk: false,
418
to_supertypes: false,
419
diagonal: false,
420
from_partitioned_ds: false,
421
maintain_order: true,
422
}
423
}
424
}
425
426
impl From<UnionArgs> for UnionOptions {
427
fn from(args: UnionArgs) -> Self {
428
UnionOptions {
429
slice: None,
430
parallel: args.parallel,
431
rows: (None, 0),
432
from_partitioned_ds: args.from_partitioned_ds,
433
flattened_by_opt: false,
434
rechunk: args.rechunk,
435
maintain_order: args.maintain_order,
436
}
437
}
438
}
439
440
/// Options for reading newline-delimited JSON (NDJSON).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
#[cfg(feature = "json")]
pub struct NDJsonReadOptions {
    /// Number of threads to use; `None` presumably lets the reader decide —
    /// TODO confirm against the reader implementation.
    pub n_threads: Option<usize>,
    /// Number of rows to inspect when inferring the schema; `None` means no
    /// fixed limit.
    pub infer_schema_length: Option<NonZeroUsize>,
    /// Size of the chunks read at a time (always non-zero).
    pub chunk_size: NonZeroUsize,
    /// Trade speed for lower memory usage.
    pub low_memory: bool,
    /// Skip lines that fail to parse instead of erroring.
    pub ignore_errors: bool,
    /// Full schema to use, bypassing inference.
    pub schema: Option<SchemaRef>,
    /// Per-column overrides applied on top of the inferred schema.
    pub schema_overwrite: Option<SchemaRef>,
}
453
454