Path: blob/main/crates/polars-plan/src/plans/python/pyarrow.rs
8384 views
use std::fmt::Write;12use polars_core::datatypes::AnyValue;3use polars_core::prelude::{TimeUnit, TimeZone};45use crate::prelude::*;67#[derive(Default, Copy, Clone)]8pub struct PyarrowArgs {9// pyarrow doesn't allow `filter([True, False])`10// but does allow `filter(field("a").isin([True, False]))`11allow_literal_series: bool,12}1314fn to_py_datetime(v: i64, tu: &TimeUnit, tz: Option<&TimeZone>) -> String {15// note: `to_py_datetime` and the `Datetime`16// dtype have to be in-scope on the python side17match tz {18None => format!("to_py_datetime({},'{}')", v, tu.to_ascii()),19Some(tz) => format!("to_py_datetime({},'{}','{}')", v, tu.to_ascii(), tz),20}21}2223fn sanitize(name: &str) -> Option<&str> {24if name.chars().all(|c| match c {25' ' => true,26'-' => true,27'_' => true,28c => c.is_alphanumeric(),29}) {30Some(name)31} else {32None33}34}3536// convert to a pyarrow expression that can be evaluated with pythons eval37pub fn predicate_to_pa(38predicate: Node,39expr_arena: &Arena<AExpr>,40args: PyarrowArgs,41) -> Option<String> {42match expr_arena.get(predicate) {43AExpr::BinaryExpr { left, right, op } => {44if op.is_comparison_or_bitwise() {45let left = predicate_to_pa(*left, expr_arena, args)?;46let right = predicate_to_pa(*right, expr_arena, args)?;47Some(format!("({left} {op} {right})"))48} else {49None50}51},52AExpr::Column(name) => {53let name = sanitize(name)?;54Some(format!("pa.compute.field('{name}')"))55},56AExpr::Literal(LiteralValue::Series(s)) => {57if !args.allow_literal_series || s.is_empty() || s.len() > 100 {58None59} else {60let mut list_repr = String::with_capacity(s.len() * 5);61list_repr.push('[');62for av in s.iter() {63match av {64AnyValue::Boolean(v) => {65let s = if v { "True" } else { "False" };66write!(list_repr, "{s},").unwrap();67},68#[cfg(feature = "dtype-datetime")]69AnyValue::Datetime(v, tu, tz) => {70let dtm = to_py_datetime(v, &tu, tz);71write!(list_repr, "{dtm},").unwrap();72},73#[cfg(feature = "dtype-date")]74AnyValue::Date(v) => {75write!(list_repr, "to_py_date({v}),").unwrap();76},77AnyValue::String(s) => {78let _ = sanitize(s)?;79write!(list_repr, "{av},").unwrap();80},81// Hard to sanitize82AnyValue::Binary(_) | AnyValue::List(_) => return None,83#[cfg(feature = "dtype-array")]84AnyValue::Array(_, _) => return None,85#[cfg(feature = "dtype-struct")]86AnyValue::Struct(_, _, _) => return None,87_ => {88write!(list_repr, "{av},").unwrap();89},90}91}92// pop last comma93list_repr.pop();94list_repr.push(']');95Some(list_repr)96}97},98AExpr::Literal(lv) => {99let av = lv.to_any_value()?;100let dtype = av.dtype();101match av.as_borrowed() {102AnyValue::String(s) => {103let s = sanitize(s)?;104Some(format!("'{s}'"))105},106AnyValue::Boolean(val) => {107// python bools are capitalized108if val {109Some("pa.compute.scalar(True)".to_string())110} else {111Some("pa.compute.scalar(False)".to_string())112}113},114#[cfg(feature = "dtype-date")]115AnyValue::Date(v) => {116// the function `to_py_date` and the `Date`117// dtype have to be in scope on the python side118Some(format!("to_py_date({v})"))119},120#[cfg(feature = "dtype-datetime")]121AnyValue::Datetime(v, tu, tz) => Some(to_py_datetime(v, &tu, tz)),122// Hard to sanitize123AnyValue::Binary(_) | AnyValue::List(_) => None,124#[cfg(feature = "dtype-array")]125AnyValue::Array(_, _) => None,126#[cfg(feature = "dtype-struct")]127AnyValue::Struct(_, _, _) => None,128// Activate once pyarrow supports them129// #[cfg(feature = "dtype-time")]130// AnyValue::Time(v) => {131// // the function `to_py_time` has to be in scope132// // on the python side133// Some(format!("to_py_time(value={v})"))134// }135// #[cfg(feature = "dtype-duration")]136// AnyValue::Duration(v, tu) => {137// // the function `to_py_timedelta` has to be in scope138// // on the python side139// Some(format!(140// "to_py_timedelta(value={}, tu='{}')",141// v,142// tu.to_ascii()143// ))144// }145av => {146if dtype.is_float() {147let val = av.extract::<f64>()?;148Some(format!("{val}"))149} else if dtype.is_integer() {150let val = av.extract::<i64>()?;151Some(format!("{val}"))152} else {153None154}155},156}157},158#[cfg(feature = "is_in")]159AExpr::Function {160function: IRFunctionExpr::Boolean(IRBooleanFunction::IsIn { .. }),161input,162..163} => {164let col = predicate_to_pa(input.first()?.node(), expr_arena, args)?;165let mut args = args;166args.allow_literal_series = true;167let values = predicate_to_pa(input.get(1)?.node(), expr_arena, args)?;168169Some(format!("({col}).isin({values})"))170},171#[cfg(feature = "is_between")]172AExpr::Function {173function: IRFunctionExpr::Boolean(IRBooleanFunction::IsBetween { closed }),174input,175..176} => {177if !matches!(expr_arena.get(input.first()?.node()), AExpr::Column(_)) {178None179} else {180let col = predicate_to_pa(input.first()?.node(), expr_arena, args)?;181let left_cmp_op = match closed {182ClosedInterval::None | ClosedInterval::Right => Operator::Gt,183ClosedInterval::Both | ClosedInterval::Left => Operator::GtEq,184};185let right_cmp_op = match closed {186ClosedInterval::None | ClosedInterval::Left => Operator::Lt,187ClosedInterval::Both | ClosedInterval::Right => Operator::LtEq,188};189190let lower = predicate_to_pa(input.get(1)?.node(), expr_arena, args)?;191let upper = predicate_to_pa(input.get(2)?.node(), expr_arena, args)?;192193Some(format!(194"(({col} {left_cmp_op} {lower}) & ({col} {right_cmp_op} {upper}))"195))196}197},198AExpr::Function {199function, input, ..200} => {201let input = input.first().unwrap().node();202let input = predicate_to_pa(input, expr_arena, args)?;203204match function {205IRFunctionExpr::Boolean(IRBooleanFunction::Not) => Some(format!("~({input})")),206IRFunctionExpr::Boolean(IRBooleanFunction::IsNull) => {207Some(format!("({input}).is_null()"))208},209IRFunctionExpr::Boolean(IRBooleanFunction::IsNotNull) => {210Some(format!("~({input}).is_null()"))211},212_ => None,213}214},215_ => None,216}217}218219220