use polars_core::prelude::*;
use polars_lazy::prelude::IntoLazy;
use polars_plan::dsl::BaseColumnUdf;
use polars_plan::prelude::UserDefinedFunction;
use polars_sql::SQLContext;
use polars_sql::function_registry::FunctionRegistry;
/// Minimal in-memory function registry used by the test below.
///
/// Stores user-defined functions in a hash map so they can be looked up by name.
struct MyFunctionRegistry {
/// Lookup key (function name as a `String`) mapped to the registered UDF.
functions: PlHashMap<String, UserDefinedFunction>,
}
impl MyFunctionRegistry {
fn new(funcs: Vec<UserDefinedFunction>) -> Self {
let functions = funcs.into_iter().map(|f| (f.name.to_string(), f)).collect();
MyFunctionRegistry { functions }
}
}
impl FunctionRegistry for MyFunctionRegistry {
    /// Stores `fun` under `name`, replacing any function previously stored under that key.
    ///
    /// Always returns `Ok(())`.
    fn register(&mut self, name: &str, fun: UserDefinedFunction) -> PolarsResult<()> {
        let key = name.to_owned();
        self.functions.insert(key, fun);
        Ok(())
    }

    /// Returns a clone of the function stored under `name`, or `None` when the name is unknown.
    ///
    /// This implementation never returns `Err`.
    fn get_udf(&self, name: &str) -> PolarsResult<Option<UserDefinedFunction>> {
        let found = self.functions.get(name);
        Ok(found.cloned())
    }

    /// Reports whether a function is stored under `name`.
    fn contains(&self, name: &str) -> bool {
        self.functions.get(name).is_some()
    }
}
/// Exercises user-defined functions through a `SQLContext` backed by `MyFunctionRegistry`:
///
/// 1. a UDF supplied when the registry is constructed (`my_custom_sum`),
/// 2. a schema error surfaced when that UDF is called with the wrong number of arguments,
/// 3. a UDF registered afterwards through `registry_mut()` under the SQL name `my_div`.
#[test]
fn test_udfs() -> PolarsResult<()> {
    // Element-wise addition of two columns; the output field mirrors the first input field.
    let my_custom_sum = UserDefinedFunction::new(
        "my_custom_sum".into(),
        BaseColumnUdf::new(
            move |cols: &mut [Column]| {
                let lhs = cols[0].as_materialized_series().clone();
                let rhs = cols[1].as_materialized_series().clone();
                (lhs + rhs).map(Column::from)
            },
            |_: &Schema, fields: &[Field]| {
                polars_ensure!(fields.len() == 2, SchemaMismatch: "expected two arguments");
                let (lhs, rhs) = (&fields[0], &fields[1]);
                polars_ensure!(lhs.dtype() == rhs.dtype(), SchemaMismatch: "mismatched types");
                Ok(lhs.clone())
            },
        ),
    );

    let registry = Arc::new(MyFunctionRegistry::new(vec![my_custom_sum]));
    let mut ctx = SQLContext::new().with_function_registry(registry);

    let frame = df! {
        "a" => &[1, 2, 3],
        "b" => &[1, 2, 3],
        "c" => &["a", "b", "c"]
    }
    .unwrap()
    .lazy();
    ctx.register("foo", frame);

    // Only checks that `execute` succeeds; the returned value is not collected here.
    assert!(ctx
        .execute("SELECT a, b, my_custom_sum(a, b) FROM foo")
        .is_ok());

    // A single argument must be rejected by the schema callback of the UDF.
    let invalid = ctx.execute("SELECT a, b, my_custom_sum(c) as invalid FROM foo");
    assert!(matches!(invalid, Err(PolarsError::SchemaMismatch(_))));

    // Element-wise division of two columns; the output field mirrors the first input field.
    let my_custom_divide = UserDefinedFunction::new(
        "my_custom_divide".into(),
        BaseColumnUdf::new(
            move |cols: &mut [Column]| {
                let numerator = cols[0].as_materialized_series().clone();
                let denominator = cols[1].as_materialized_series().clone();
                (numerator / denominator).map(Column::from)
            },
            |_: &Schema, fields: &[Field]| {
                polars_ensure!(fields.len() == 2, SchemaMismatch: "expected two arguments");
                let (lhs, rhs) = (&fields[0], &fields[1]);
                polars_ensure!(lhs.dtype() == rhs.dtype(), SchemaMismatch: "mismatched types");
                Ok(lhs.clone())
            },
        ),
    );

    // Registered under "my_div", which differs from the UDF's own name "my_custom_divide";
    // the query below refers to it by the registered key.
    ctx.registry_mut().register("my_div", my_custom_divide)?;

    let plan = ctx.execute("SELECT a, b, my_div(a, b) as my_div FROM foo")?;
    let actual = plan.collect()?;

    let expected = df! {
        "a" => &[1, 2, 3],
        "b" => &[1, 2, 3],
        "my_div" => &[1, 1, 1]
    }?;
    assert!(expected.equals_missing(&actual));

    Ok(())
}