Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/options.rs
6940 views
1
use bitflags::bitflags;
2
use polars_core::prelude::*;
3
use polars_core::utils::SuperTypeOptions;
4
#[cfg(feature = "serde")]
5
use serde::{Deserialize, Serialize};
6
7
use crate::plans::PlSmallStr;
8
9
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
10
#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))]
11
pub struct DistinctOptionsIR {
12
/// Subset of columns that will be taken into account.
13
pub subset: Option<Arc<[PlSmallStr]>>,
14
/// This will maintain the order of the input.
15
/// Note that this is more expensive.
16
/// `maintain_order` is not supported in the streaming
17
/// engine.
18
pub maintain_order: bool,
19
/// Which rows to keep.
20
pub keep_strategy: UniqueKeepStrategy,
21
/// Take only a slice of the result
22
pub slice: Option<(i64, usize)>,
23
}
24
25
// A boolean that defaults to `true`. Setting it to `false` is the unsafe
// direction: in this file that is only done by the `unsafe fn`
// `FunctionOptions::no_check_lengths`.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct UnsafeBool(bool);
impl Default for UnsafeBool {
    fn default() -> Self {
        Self(true)
    }
}
35
36
#[cfg(feature = "dsl-schema")]
impl schemars::JsonSchema for FunctionFlags {
    fn schema_name() -> String {
        String::from("FunctionFlags")
    }

    fn schema_id() -> std::borrow::Cow<'static, str> {
        std::borrow::Cow::Borrowed(concat!(module_path!(), "::", "FunctionFlags"))
    }

    /// Describe `FunctionFlags` as a string with `format: "bitflags"`.
    ///
    /// A map from every flag name to its bit pattern is attached as the
    /// `bitflags` extension, so adding, removing or renumbering a flag changes
    /// the generated schema (which makes such changes detectable).
    fn json_schema(_generator: &mut schemars::r#gen::SchemaGenerator) -> schemars::schema::Schema {
        use schemars::schema::{InstanceType, Schema, SchemaObject};
        use serde_json::{Map, Value};

        let mut name_to_bits: Map<String, Value> = Map::new();
        for (flag_name, flag) in Self::all().iter_names() {
            name_to_bits.insert(flag_name.to_owned(), Value::from(flag.bits()));
        }

        let extensions =
            schemars::Map::from_iter([("bitflags".to_owned(), Value::Object(name_to_bits))]);

        Schema::Object(SchemaObject {
            instance_type: Some(InstanceType::String.into()),
            format: Some("bitflags".to_owned()),
            extensions,
            ..Default::default()
        })
    }
}
65
66
bitflags!(
    /// Set of properties describing how a function expression behaves.
    #[repr(transparent)]
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
    pub struct FunctionFlags: u16 {
        /// The physical expression may rename the output of this function.
        /// If not set, the physical engine will ensure the output is named after
        /// the left input expression.
        const ALLOW_RENAME = 1 << 0;
        /// If set, the `Series` passed to the function in a group_by operation
        /// will have its name set. This costs an extra heap allocation per group.
        const PASS_NAME_TO_APPLY = 1 << 1;
        /// There are two ways of expanding wildcards:
        ///
        /// Say the schema is 'a', 'b' and there is a function `f`. In this case, `f('*')` can
        /// expand to:
        /// 1. `f('a', 'b')`
        /// 2. `f('a'), f('b')`
        ///
        /// Setting this flag leads to behavior 1.
        ///
        /// This also applies to regex expansion.
        const INPUT_WILDCARD_EXPANSION = 1 << 2;
        /// Automatically explode on unit length if it ran as final aggregation.
        ///
        /// This is the case for aggregations like sum, min, covariance etc.
        /// We need to know this because we cannot see the difference between
        /// the following functions based on the output type and number of elements:
        ///
        /// x: {1, 2, 3}
        ///
        /// head_1(x) -> {1}
        /// sum(x) -> {6}
        ///
        /// Mutually exclusive with `LENGTH_PRESERVING`.
        const RETURNS_SCALAR = 1 << 3;
        /// The function may re-enter the engine, e.g. a UDF that uses Polars
        /// within the UDF. Recursively entering the engine can sometimes lead to
        /// deadlocks; this flag must be set to handle that.
        const OPTIONAL_RE_ENTRANT = 1 << 4;
        /// Whether this function allows no inputs.
        const ALLOW_EMPTY_INPUTS = 1 << 5;

        /// Given a function f and a column of values [v1, ..., vn],
        /// f is row-separable iff, for every split point m,
        /// f([v1, ..., vn]) = concat(f([v1, ..., vm]), f([vm+1, ..., vn]))
        const ROW_SEPARABLE = 1 << 6;
        /// Given a function f and a column of values [v1, ..., vn],
        /// f is length preserving iff len(f([v1, ..., vn])) = n
        ///
        /// Mutually exclusive with `RETURNS_SCALAR`.
        const LENGTH_PRESERVING = 1 << 7;
        /// NULLs on the first input are propagated to the output.
        const PRESERVES_NULL_FIRST_INPUT = 1 << 8;
        /// NULLs on any input are propagated to the output.
        const PRESERVES_NULL_ALL_INPUTS = 1 << 9;
    }
);
124
125
impl FunctionFlags {
126
pub fn set_elementwise(&mut self) {
127
*self |= Self::ROW_SEPARABLE | Self::LENGTH_PRESERVING;
128
}
129
130
pub fn is_elementwise(self) -> bool {
131
self.contains(Self::ROW_SEPARABLE | Self::LENGTH_PRESERVING)
132
}
133
134
pub fn is_row_separable(self) -> bool {
135
self.contains(Self::ROW_SEPARABLE)
136
}
137
138
pub fn is_length_preserving(self) -> bool {
139
self.contains(Self::LENGTH_PRESERVING)
140
}
141
142
pub fn returns_scalar(self) -> bool {
143
self.contains(Self::RETURNS_SCALAR)
144
}
145
}
146
147
impl Default for FunctionFlags {
148
fn default() -> Self {
149
Self::from_bits_truncate(0)
150
}
151
}
152
153
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
154
pub enum CastingRules {
155
/// Whether information may be lost during cast. E.g. a float to int is considered lossy,
156
/// whereas int to int is considered lossless.
157
/// Overflowing is not considered in this flag, that's handled in `strict` casting
158
FirstArgLossless,
159
Supertype(SuperTypeOptions),
160
}
161
162
impl CastingRules {
163
pub fn cast_to_supertypes() -> CastingRules {
164
Self::Supertype(Default::default())
165
}
166
}
167
168
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
169
#[cfg_attr(any(feature = "serde"), derive(Serialize, Deserialize))]
170
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
171
pub struct FunctionOptions {
172
// Validate the output of a `map`.
173
// this should always be true or we could OOB
174
pub check_lengths: UnsafeBool,
175
pub flags: FunctionFlags,
176
177
/// Options used when deciding how to cast the arguments of the function.
178
#[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(skip))]
179
pub cast_options: Option<CastingRules>,
180
}
181
182
impl FunctionOptions {
183
#[cfg(feature = "fused")]
184
pub(crate) unsafe fn no_check_lengths(&mut self) {
185
self.check_lengths = UnsafeBool(false);
186
}
187
pub fn check_lengths(&self) -> bool {
188
self.check_lengths.0
189
}
190
191
pub fn set_elementwise(&mut self) {
192
self.flags.set_elementwise();
193
}
194
195
pub fn is_elementwise(&self) -> bool {
196
self.flags.is_elementwise()
197
}
198
199
pub fn is_length_preserving(&self) -> bool {
200
self.flags.contains(FunctionFlags::LENGTH_PRESERVING)
201
}
202
203
pub fn is_row_separable(&self) -> bool {
204
self.flags.is_row_separable()
205
}
206
207
pub fn returns_scalar(&self) -> bool {
208
self.flags.returns_scalar()
209
}
210
211
pub fn elementwise() -> FunctionOptions {
212
FunctionOptions {
213
..Default::default()
214
}
215
.with_flags(|f| f | FunctionFlags::ROW_SEPARABLE | FunctionFlags::LENGTH_PRESERVING)
216
}
217
218
pub fn elementwise_with_infer() -> FunctionOptions {
219
Self::length_preserving()
220
}
221
222
pub fn row_separable() -> FunctionOptions {
223
FunctionOptions {
224
..Default::default()
225
}
226
.with_flags(|f| f | FunctionFlags::ROW_SEPARABLE)
227
}
228
229
pub fn length_preserving() -> FunctionOptions {
230
FunctionOptions {
231
..Default::default()
232
}
233
.with_flags(|f| f | FunctionFlags::LENGTH_PRESERVING)
234
}
235
236
pub fn groupwise() -> FunctionOptions {
237
FunctionOptions {
238
..Default::default()
239
}
240
}
241
242
pub fn aggregation() -> FunctionOptions {
243
let mut options = Self::groupwise();
244
options.flags |= FunctionFlags::RETURNS_SCALAR;
245
options
246
}
247
248
pub fn with_supertyping(self, supertype_options: SuperTypeOptions) -> FunctionOptions {
249
self.with_casting_rules(CastingRules::Supertype(supertype_options))
250
}
251
252
pub fn with_casting_rules(mut self, casting_rules: CastingRules) -> FunctionOptions {
253
self.cast_options = Some(casting_rules);
254
self
255
}
256
257
pub fn with_flags(mut self, f: impl Fn(FunctionFlags) -> FunctionFlags) -> FunctionOptions {
258
self.flags = f(self.flags);
259
self
260
}
261
}
262
263
impl Default for FunctionOptions {
264
fn default() -> Self {
265
FunctionOptions {
266
check_lengths: UnsafeBool(true),
267
cast_options: Default::default(),
268
flags: Default::default(),
269
}
270
}
271
}
272
273
// Options of a projection. Plain `//` comments are used here because, with the
// `dsl-schema` feature, `///` doc comments would end up in the derived JSON schema.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ProjectionOptions {
    pub run_parallel: bool,
    pub duplicate_check: bool,
    // Should length-1 Series be broadcast to the length of the dataframe.
    // Only used by CSE optimizer
    pub should_broadcast: bool,
}

impl Default for ProjectionOptions {
    /// Every option enabled.
    fn default() -> Self {
        let enabled = true;
        Self {
            run_parallel: enabled,
            duplicate_check: enabled,
            should_broadcast: enabled,
        }
    }
}

impl ProjectionOptions {
    /// Conservatively merge the options of two [`ProjectionOptions`].
    ///
    /// `run_parallel` and `duplicate_check` stay enabled only when both sides
    /// enable them; `should_broadcast` is enabled when either side enables it.
    pub fn merge_options(&self, other: &Self) -> Self {
        let run_parallel = self.run_parallel && other.run_parallel;
        let duplicate_check = self.duplicate_check && other.duplicate_check;
        let should_broadcast = self.should_broadcast || other.should_broadcast;
        Self {
            run_parallel,
            duplicate_check,
            should_broadcast,
        }
    }
}
304
305