Path: blob/main/pyo3-polars/example/derive_expression/expression_lib/src/expressions.rs
7884 views
use std::fmt::Write;12use polars::prelude::*;3use polars_plan::prelude::FieldsMapper;4use pyo3_polars::derive::{polars_expr, CallerContext};5use pyo3_polars::export::polars_core::POOL;6use serde::Deserialize;78#[derive(Deserialize)]9struct PigLatinKwargs {10capitalize: bool,11}1213fn pig_latin_str(value: &str, capitalize: bool, output: &mut String) {14if let Some(first_char) = value.chars().next() {15if capitalize {16for c in value.chars().skip(1).map(|char| char.to_uppercase()) {17write!(output, "{c}").unwrap()18}19write!(output, "AY").unwrap()20} else {21let offset = first_char.len_utf8();22write!(output, "{}{}ay", &value[offset..], first_char).unwrap()23}24}25}2627#[polars_expr(output_type=String)]28fn pig_latinnify(inputs: &[Series], kwargs: PigLatinKwargs) -> PolarsResult<Series> {29let ca = inputs[0].str()?;30let out: StringChunked = ca.apply_into_string_amortized(|value, output| {31pig_latin_str(value, kwargs.capitalize, output)32});33Ok(out.into_series())34}3536fn split_offsets(len: usize, n: usize) -> Vec<(usize, usize)> {37if n == 1 {38vec![(0, len)]39} else {40let chunk_size = len / n;4142(0..n)43.map(|partition| {44let offset = partition * chunk_size;45let len = if partition == (n - 1) {46len - offset47} else {48chunk_size49};50(partition * chunk_size, len)51})52.collect()53}54}5556/// This expression will run in parallel if the `context` allows it.57#[polars_expr(output_type=String)]58fn pig_latinnify_with_parallelism(59inputs: &[Series],60context: CallerContext,61kwargs: PigLatinKwargs,62) -> PolarsResult<Series> {63use rayon::prelude::*;64let ca = inputs[0].str()?;6566if context.parallel() {67let out: StringChunked = ca.apply_into_string_amortized(|value, output| {68pig_latin_str(value, kwargs.capitalize, output)69});70Ok(out.into_series())71} else {72POOL.install(|| {73let n_threads = POOL.current_num_threads();74let splits = split_offsets(ca.len(), n_threads);7576let chunks: Vec<_> = splits77.into_par_iter()78.map(|(offset, len)| {79let sliced = ca.slice(offset as i64, len);80let out = sliced.apply_into_string_amortized(|value, output| {81pig_latin_str(value, kwargs.capitalize, output)82});83out.downcast_iter().cloned().collect::<Vec<_>>()84})85.collect();8687Ok(88StringChunked::from_chunk_iter(ca.name().clone(), chunks.into_iter().flatten())89.into_series(),90)91})92}93}9495#[polars_expr(output_type=Float64)]96fn jaccard_similarity(inputs: &[Series]) -> PolarsResult<Series> {97let a = inputs[0].list()?;98let b = inputs[1].list()?;99crate::distances::naive_jaccard_sim(a, b).map(|ca| ca.into_series())100}101102#[polars_expr(output_type=Float64)]103fn hamming_distance(inputs: &[Series]) -> PolarsResult<Series> {104let a = inputs[0].str()?;105let b = inputs[1].str()?;106let out: UInt32Chunked =107arity::binary_elementwise_values(a, b, crate::distances::naive_hamming_dist);108Ok(out.into_series())109}110111fn haversine_output(input_fields: &[Field]) -> PolarsResult<Field> {112FieldsMapper::new(input_fields).map_to_float_dtype()113}114115#[polars_expr(output_type_func=haversine_output)]116fn haversine(inputs: &[Series]) -> PolarsResult<Series> {117let out = match inputs[0].dtype() {118DataType::Float32 => {119let start_lat = inputs[0].f32().unwrap();120let start_long = inputs[1].f32().unwrap();121let end_lat = inputs[2].f32().unwrap();122let end_long = inputs[3].f32().unwrap();123crate::distances::naive_haversine(start_lat, start_long, end_lat, end_long)?124.into_series()125},126DataType::Float64 => {127let start_lat = inputs[0].f64().unwrap();128let start_long = inputs[1].f64().unwrap();129let end_lat = inputs[2].f64().unwrap();130let end_long = inputs[3].f64().unwrap();131crate::distances::naive_haversine(start_lat, start_long, end_lat, end_long)?132.into_series()133},134_ => unimplemented!(),135};136Ok(out)137}138139/// The `DefaultKwargs` isn't very ergonomic as it doesn't validate any schema.140/// Provide your own kwargs struct with the proper schema and accept that type141/// in your plugin expression.142#[derive(Deserialize)]143pub struct MyKwargs {144float_arg: f64,145integer_arg: i64,146string_arg: String,147boolean_arg: bool,148}149150/// If you want to accept `kwargs`. You define a `kwargs` argument151/// on the second position in you plugin. You can provide any custom struct that is deserializable152/// with the pickle protocol (on the rust side).153#[polars_expr(output_type=String)]154fn append_kwargs(input: &[Series], kwargs: MyKwargs) -> PolarsResult<Series> {155let input = &input[0];156let input = input.cast(&DataType::String)?;157let ca = input.str().unwrap();158159Ok(ca160.apply_into_string_amortized(|val, buf| {161write!(162buf,163"{}-{}-{}-{}-{}",164val, kwargs.float_arg, kwargs.integer_arg, kwargs.string_arg, kwargs.boolean_arg165)166.unwrap()167})168.into_series())169}170171#[polars_expr(output_type=Boolean)]172fn is_leap_year(input: &[Series]) -> PolarsResult<Series> {173let input = &input[0];174let ca = input.date()?;175176let out: BooleanChunked = ca177.as_date_iter()178.map(|opt_dt| opt_dt.map(|dt| dt.leap_year()))179.collect_ca(ca.name().clone());180181Ok(out.into_series())182}183184#[polars_expr(output_type=Boolean)]185fn panic(_input: &[Series]) -> PolarsResult<Series> {186todo!()187}188189#[derive(Deserialize)]190struct TimeZone {191tz: String,192}193194fn convert_timezone(input_fields: &[Field], kwargs: TimeZone) -> PolarsResult<Field> {195FieldsMapper::new(input_fields).try_map_dtype(|dtype| match dtype {196DataType::Datetime(tu, _) => Ok(DataType::Datetime(197*tu,198datatypes::TimeZone::opt_try_new(Some(kwargs.tz))?,199)),200_ => polars_bail!(ComputeError: "expected datetime"),201})202}203204/// This expression is for demonstration purposes as we have a dedicated205/// `convert_time_zone` in Polars.206#[polars_expr(output_type_func_with_kwargs=convert_timezone)]207fn change_time_zone(input: &[Series], kwargs: TimeZone) -> PolarsResult<Series> {208let input = &input[0];209let ca = input.datetime()?;210211let mut out = ca.clone();212213let Some(timezone) = datatypes::TimeZone::opt_try_new(Some(kwargs.tz))? else {214polars_bail!(ComputeError: "expected timezone")215};216217out.set_time_zone(timezone)?;218Ok(out.into_series())219}220221222