Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-sql/tests/udf.rs
6939 views
1
use polars_core::prelude::*;
2
use polars_lazy::prelude::IntoLazy;
3
use polars_plan::dsl::BaseColumnUdf;
4
use polars_plan::prelude::UserDefinedFunction;
5
use polars_sql::SQLContext;
6
use polars_sql::function_registry::FunctionRegistry;
7
8
struct MyFunctionRegistry {
9
functions: PlHashMap<String, UserDefinedFunction>,
10
}
11
12
impl MyFunctionRegistry {
13
fn new(funcs: Vec<UserDefinedFunction>) -> Self {
14
let functions = funcs.into_iter().map(|f| (f.name.to_string(), f)).collect();
15
MyFunctionRegistry { functions }
16
}
17
}
18
19
impl FunctionRegistry for MyFunctionRegistry {
20
fn register(&mut self, name: &str, fun: UserDefinedFunction) -> PolarsResult<()> {
21
self.functions.insert(name.to_string(), fun);
22
Ok(())
23
}
24
25
fn get_udf(&self, name: &str) -> PolarsResult<Option<UserDefinedFunction>> {
26
Ok(self.functions.get(name).cloned())
27
}
28
29
fn contains(&self, name: &str) -> bool {
30
self.functions.contains_key(name)
31
}
32
}
33
34
#[test]
35
fn test_udfs() -> PolarsResult<()> {
36
let my_custom_sum = UserDefinedFunction::new(
37
"my_custom_sum".into(),
38
BaseColumnUdf::new(
39
move |c: &mut [Column]| {
40
let first = c[0].as_materialized_series().clone();
41
let second = c[1].as_materialized_series().clone();
42
(first + second).map(Column::from)
43
},
44
|_: &Schema, fs: &[Field]| {
45
// UDF is responsible for schema validation
46
polars_ensure!(fs.len() == 2, SchemaMismatch: "expected two arguments");
47
let first = &fs[0];
48
let second = &fs[1];
49
50
if first.dtype() != second.dtype() {
51
polars_bail!(SchemaMismatch: "mismatched types")
52
}
53
Ok(first.clone())
54
},
55
),
56
);
57
58
let mut ctx = SQLContext::new()
59
.with_function_registry(Arc::new(MyFunctionRegistry::new(vec![my_custom_sum])));
60
61
let df = df! {
62
"a" => &[1, 2, 3],
63
"b" => &[1, 2, 3],
64
"c" => &["a", "b", "c"]
65
}
66
.unwrap()
67
.lazy();
68
69
ctx.register("foo", df);
70
let res = ctx.execute("SELECT a, b, my_custom_sum(a, b) FROM foo");
71
assert!(res.is_ok());
72
73
// schema is invalid so it will fail
74
assert!(matches!(
75
ctx.execute("SELECT a, b, my_custom_sum(c) as invalid FROM foo"),
76
Err(PolarsError::SchemaMismatch(_))
77
));
78
79
// create a new UDF to be registered on the context
80
let my_custom_divide = UserDefinedFunction::new(
81
"my_custom_divide".into(),
82
BaseColumnUdf::new(
83
move |c: &mut [Column]| {
84
let first = c[0].as_materialized_series().clone();
85
let second = c[1].as_materialized_series().clone();
86
(first / second).map(Column::from)
87
},
88
|_: &Schema, fs: &[Field]| {
89
// UDF is responsible for schema validation
90
polars_ensure!(fs.len() == 2, SchemaMismatch: "expected two arguments");
91
let first = &fs[0];
92
let second = &fs[1];
93
94
if first.dtype() != second.dtype() {
95
polars_bail!(SchemaMismatch: "mismatched types")
96
}
97
Ok(first.clone())
98
},
99
),
100
);
101
102
// register a new UDF on an existing context
103
ctx.registry_mut().register("my_div", my_custom_divide)?;
104
105
// execute the query
106
let res = ctx
107
.execute("SELECT a, b, my_div(a, b) as my_div FROM foo")?
108
.collect()?;
109
let expected = df! {
110
"a" => &[1, 2, 3],
111
"b" => &[1, 2, 3],
112
"my_div" => &[1, 1, 1]
113
}?;
114
assert!(expected.equals_missing(&res));
115
116
Ok(())
117
}
118
119