Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/pyo3-polars/example/derive_expression/expression_lib/src/expressions.rs
7884 views
1
use std::fmt::Write;
2
3
use polars::prelude::*;
4
use polars_plan::prelude::FieldsMapper;
5
use pyo3_polars::derive::{polars_expr, CallerContext};
6
use pyo3_polars::export::polars_core::POOL;
7
use serde::Deserialize;
8
9
#[derive(Deserialize)]
10
struct PigLatinKwargs {
11
capitalize: bool,
12
}
13
14
/// Appends the pig-latin form of `value` to `output`.
///
/// The first character of `value` is moved behind the remaining characters
/// and the suffix `ay` is added. When `capitalize` is true the moved
/// character, the remainder and the suffix are all upper-cased.
///
/// An empty `value` appends nothing. Existing content of `output` is kept;
/// the result is only appended.
fn pig_latin_str(value: &str, capitalize: bool, output: &mut String) {
    let mut chars = value.chars();
    let Some(first_char) = chars.next() else {
        return;
    };
    // Everything after the first character, split on a char boundary.
    let rest = chars.as_str();

    if capitalize {
        // `to_uppercase` yields an iterator because one char may expand to several.
        output.extend(rest.chars().flat_map(char::to_uppercase));
        // The moved first character must be kept here as well, exactly like
        // in the lower-case branch; previously it was silently dropped.
        output.extend(first_char.to_uppercase());
        output.push_str("AY");
    } else {
        output.push_str(rest);
        output.push(first_char);
        output.push_str("ay");
    }
}
27
28
#[polars_expr(output_type=String)]
29
fn pig_latinnify(inputs: &[Series], kwargs: PigLatinKwargs) -> PolarsResult<Series> {
30
let ca = inputs[0].str()?;
31
let out: StringChunked = ca.apply_into_string_amortized(|value, output| {
32
pig_latin_str(value, kwargs.capitalize, output)
33
});
34
Ok(out.into_series())
35
}
36
37
/// Splits the range `0..len` into contiguous `(offset, length)` pairs.
///
/// For `n >= 2` exactly `n` pairs are returned: every partition holds
/// `len / n` elements, and the last one additionally takes the remainder.
/// When `len < n` the leading partitions are therefore empty.
///
/// `n == 0` is treated like `n == 1` and yields the single pair `(0, len)`
/// instead of dividing by zero.
fn split_offsets(len: usize, n: usize) -> Vec<(usize, usize)> {
    if n <= 1 {
        return vec![(0, len)];
    }

    let chunk_size = len / n;
    let last = n - 1;

    (0..n)
        .map(|partition| {
            let offset = partition * chunk_size;
            // The final partition absorbs whatever integer division left over.
            let part_len = if partition == last {
                len - offset
            } else {
                chunk_size
            };
            (offset, part_len)
        })
        .collect()
}
56
57
/// This expression will run in parallel if the `context` allows it.
58
#[polars_expr(output_type=String)]
59
fn pig_latinnify_with_parallelism(
60
inputs: &[Series],
61
context: CallerContext,
62
kwargs: PigLatinKwargs,
63
) -> PolarsResult<Series> {
64
use rayon::prelude::*;
65
let ca = inputs[0].str()?;
66
67
if context.parallel() {
68
let out: StringChunked = ca.apply_into_string_amortized(|value, output| {
69
pig_latin_str(value, kwargs.capitalize, output)
70
});
71
Ok(out.into_series())
72
} else {
73
POOL.install(|| {
74
let n_threads = POOL.current_num_threads();
75
let splits = split_offsets(ca.len(), n_threads);
76
77
let chunks: Vec<_> = splits
78
.into_par_iter()
79
.map(|(offset, len)| {
80
let sliced = ca.slice(offset as i64, len);
81
let out = sliced.apply_into_string_amortized(|value, output| {
82
pig_latin_str(value, kwargs.capitalize, output)
83
});
84
out.downcast_iter().cloned().collect::<Vec<_>>()
85
})
86
.collect();
87
88
Ok(
89
StringChunked::from_chunk_iter(ca.name().clone(), chunks.into_iter().flatten())
90
.into_series(),
91
)
92
})
93
}
94
}
95
96
#[polars_expr(output_type=Float64)]
97
fn jaccard_similarity(inputs: &[Series]) -> PolarsResult<Series> {
98
let a = inputs[0].list()?;
99
let b = inputs[1].list()?;
100
crate::distances::naive_jaccard_sim(a, b).map(|ca| ca.into_series())
101
}
102
103
#[polars_expr(output_type=Float64)]
104
fn hamming_distance(inputs: &[Series]) -> PolarsResult<Series> {
105
let a = inputs[0].str()?;
106
let b = inputs[1].str()?;
107
let out: UInt32Chunked =
108
arity::binary_elementwise_values(a, b, crate::distances::naive_hamming_dist);
109
Ok(out.into_series())
110
}
111
112
fn haversine_output(input_fields: &[Field]) -> PolarsResult<Field> {
113
FieldsMapper::new(input_fields).map_to_float_dtype()
114
}
115
116
#[polars_expr(output_type_func=haversine_output)]
117
fn haversine(inputs: &[Series]) -> PolarsResult<Series> {
118
let out = match inputs[0].dtype() {
119
DataType::Float32 => {
120
let start_lat = inputs[0].f32().unwrap();
121
let start_long = inputs[1].f32().unwrap();
122
let end_lat = inputs[2].f32().unwrap();
123
let end_long = inputs[3].f32().unwrap();
124
crate::distances::naive_haversine(start_lat, start_long, end_lat, end_long)?
125
.into_series()
126
},
127
DataType::Float64 => {
128
let start_lat = inputs[0].f64().unwrap();
129
let start_long = inputs[1].f64().unwrap();
130
let end_lat = inputs[2].f64().unwrap();
131
let end_long = inputs[3].f64().unwrap();
132
crate::distances::naive_haversine(start_lat, start_long, end_lat, end_long)?
133
.into_series()
134
},
135
_ => unimplemented!(),
136
};
137
Ok(out)
138
}
139
140
/// The `DefaultKwargs` isn't very ergonomic as it doesn't validate any schema.
141
/// Provide your own kwargs struct with the proper schema and accept that type
142
/// in your plugin expression.
143
#[derive(Deserialize)]
144
pub struct MyKwargs {
145
float_arg: f64,
146
integer_arg: i64,
147
string_arg: String,
148
boolean_arg: bool,
149
}
150
151
/// If you want to accept `kwargs`. You define a `kwargs` argument
152
/// on the second position in you plugin. You can provide any custom struct that is deserializable
153
/// with the pickle protocol (on the rust side).
154
#[polars_expr(output_type=String)]
155
fn append_kwargs(input: &[Series], kwargs: MyKwargs) -> PolarsResult<Series> {
156
let input = &input[0];
157
let input = input.cast(&DataType::String)?;
158
let ca = input.str().unwrap();
159
160
Ok(ca
161
.apply_into_string_amortized(|val, buf| {
162
write!(
163
buf,
164
"{}-{}-{}-{}-{}",
165
val, kwargs.float_arg, kwargs.integer_arg, kwargs.string_arg, kwargs.boolean_arg
166
)
167
.unwrap()
168
})
169
.into_series())
170
}
171
172
#[polars_expr(output_type=Boolean)]
173
fn is_leap_year(input: &[Series]) -> PolarsResult<Series> {
174
let input = &input[0];
175
let ca = input.date()?;
176
177
let out: BooleanChunked = ca
178
.as_date_iter()
179
.map(|opt_dt| opt_dt.map(|dt| dt.leap_year()))
180
.collect_ca(ca.name().clone());
181
182
Ok(out.into_series())
183
}
184
185
#[polars_expr(output_type=Boolean)]
186
fn panic(_input: &[Series]) -> PolarsResult<Series> {
187
todo!()
188
}
189
190
#[derive(Deserialize)]
191
struct TimeZone {
192
tz: String,
193
}
194
195
fn convert_timezone(input_fields: &[Field], kwargs: TimeZone) -> PolarsResult<Field> {
196
FieldsMapper::new(input_fields).try_map_dtype(|dtype| match dtype {
197
DataType::Datetime(tu, _) => Ok(DataType::Datetime(
198
*tu,
199
datatypes::TimeZone::opt_try_new(Some(kwargs.tz))?,
200
)),
201
_ => polars_bail!(ComputeError: "expected datetime"),
202
})
203
}
204
205
/// This expression is for demonstration purposes as we have a dedicated
206
/// `convert_time_zone` in Polars.
207
#[polars_expr(output_type_func_with_kwargs=convert_timezone)]
208
fn change_time_zone(input: &[Series], kwargs: TimeZone) -> PolarsResult<Series> {
209
let input = &input[0];
210
let ca = input.datetime()?;
211
212
let mut out = ca.clone();
213
214
let Some(timezone) = datatypes::TimeZone::opt_try_new(Some(kwargs.tz))? else {
215
polars_bail!(ComputeError: "expected timezone")
216
};
217
218
out.set_time_zone(timezone)?;
219
Ok(out.into_series())
220
}
221
222