Path: blob/main/crates/polars-python/src/expr/selector.rs
7889 views
use std::hash::{Hash, Hasher};1use std::sync::Arc;23use polars::prelude::{4DataType, DataTypeSelector, Selector, TimeUnit, TimeUnitSet, TimeZone, TimeZoneSet,5};6use polars_plan::dsl;7use pyo3::exceptions::PyTypeError;8use pyo3::{PyResult, pyclass};910use crate::prelude::Wrap;1112#[pyclass(frozen)]13#[repr(transparent)]14#[derive(Clone)]15pub struct PySelector {16pub inner: Selector,17}1819impl From<Selector> for PySelector {20fn from(inner: Selector) -> Self {21Self { inner }22}23}2425fn parse_time_unit_set(time_units: Vec<Wrap<TimeUnit>>) -> TimeUnitSet {26let mut tu = TimeUnitSet::empty();27for v in time_units {28match v.0 {29TimeUnit::Nanoseconds => tu |= TimeUnitSet::NANO_SECONDS,30TimeUnit::Microseconds => tu |= TimeUnitSet::MICRO_SECONDS,31TimeUnit::Milliseconds => tu |= TimeUnitSet::MILLI_SECONDS,32}33}34tu35}3637pub fn parse_datatype_selector(selector: PySelector) -> PyResult<DataTypeSelector> {38selector.inner.to_dtype_selector().ok_or_else(|| {39PyTypeError::new_err(format!(40"expected datatype based expression got '{}'",41selector.inner42))43})44}4546#[cfg(feature = "pymethods")]47#[pyo3::pymethods]48impl PySelector {49fn union(&self, other: &Self) -> Self {50Self {51inner: self.inner.clone() | other.inner.clone(),52}53}5455fn difference(&self, other: &Self) -> Self {56Self {57inner: self.inner.clone() - other.inner.clone(),58}59}6061fn exclusive_or(&self, other: &Self) -> Self {62Self {63inner: self.inner.clone() ^ other.inner.clone(),64}65}6667fn intersect(&self, other: &Self) -> Self {68Self {69inner: self.inner.clone() & other.inner.clone(),70}71}7273#[staticmethod]74fn by_dtype(dtypes: Vec<Wrap<DataType>>) -> Self {75let dtypes = dtypes.into_iter().map(|x| x.0).collect::<Vec<_>>();76dsl::dtype_cols(dtypes).as_selector().into()77}7879#[staticmethod]80fn by_name(names: Vec<String>, strict: bool) -> Self {81dsl::by_name(names, strict).into()82}8384#[staticmethod]85fn by_index(indices: Vec<i64>, strict: bool) -> Self {86Selector::ByIndex {87indices: indices.into(),88strict,89}90.into()91}9293#[staticmethod]94fn first(strict: bool) -> Self {95Selector::ByIndex {96indices: [0].into(),97strict,98}99.into()100}101102#[staticmethod]103fn last(strict: bool) -> Self {104Selector::ByIndex {105indices: [-1].into(),106strict,107}108.into()109}110111#[staticmethod]112fn matches(pattern: String) -> Self {113Selector::Matches(pattern.into()).into()114}115116#[staticmethod]117fn enum_() -> Self {118DataTypeSelector::Enum.as_selector().into()119}120121#[staticmethod]122fn categorical() -> Self {123DataTypeSelector::Categorical.as_selector().into()124}125126#[staticmethod]127fn nested() -> Self {128DataTypeSelector::Nested.as_selector().into()129}130131#[staticmethod]132fn list(inner_dst: Option<Self>) -> PyResult<Self> {133let inner_dst = match inner_dst {134None => None,135Some(inner_dst) => Some(Arc::new(parse_datatype_selector(inner_dst)?)),136};137Ok(DataTypeSelector::List(inner_dst).as_selector().into())138}139140#[staticmethod]141fn array(inner_dst: Option<Self>, width: Option<usize>) -> PyResult<Self> {142let inner_dst = match inner_dst {143None => None,144Some(inner_dst) => Some(Arc::new(parse_datatype_selector(inner_dst)?)),145};146Ok(DataTypeSelector::Array(inner_dst, width)147.as_selector()148.into())149}150151#[staticmethod]152fn struct_() -> Self {153DataTypeSelector::Struct.as_selector().into()154}155156#[staticmethod]157fn integer() -> Self {158DataTypeSelector::Integer.as_selector().into()159}160161#[staticmethod]162fn signed_integer() -> Self {163DataTypeSelector::SignedInteger.as_selector().into()164}165166#[staticmethod]167fn unsigned_integer() -> Self {168DataTypeSelector::UnsignedInteger.as_selector().into()169}170171#[staticmethod]172fn float() -> Self {173DataTypeSelector::Float.as_selector().into()174}175176#[staticmethod]177fn decimal() -> Self {178DataTypeSelector::Decimal.as_selector().into()179}180181#[staticmethod]182fn numeric() -> Self {183DataTypeSelector::Numeric.as_selector().into()184}185186#[staticmethod]187fn temporal() -> Self {188DataTypeSelector::Temporal.as_selector().into()189}190191#[staticmethod]192fn datetime(tu: Vec<Wrap<TimeUnit>>, tz: Vec<Wrap<Option<TimeZone>>>) -> Self {193use TimeZoneSet as TZS;194195let mut allow_unset = false;196let mut allow_set = false;197let mut any_of: Vec<TimeZone> = Vec::new();198199let tu = parse_time_unit_set(tu);200for t in tz {201let t = t.0;202match t {203None => allow_unset = true,204Some(s) if s.as_str() == "*" => allow_set = true,205Some(t) => any_of.push(t),206}207}208209let tzs = match (allow_unset, allow_set) {210(true, true) => TZS::Any,211(false, true) => TZS::AnySet,212(true, false) if any_of.is_empty() => TZS::Unset,213(true, false) => TZS::UnsetOrAnyOf(any_of.into()),214(false, false) => TZS::AnyOf(any_of.into()),215};216DataTypeSelector::Datetime(tu, tzs).as_selector().into()217}218219#[staticmethod]220fn duration(tu: Vec<Wrap<TimeUnit>>) -> Self {221let tu = parse_time_unit_set(tu);222DataTypeSelector::Duration(tu).as_selector().into()223}224225#[staticmethod]226fn object() -> Self {227DataTypeSelector::Object.as_selector().into()228}229230#[staticmethod]231fn empty() -> Self {232dsl::empty().into()233}234235#[staticmethod]236fn all() -> Self {237dsl::all().into()238}239240fn hash(&self) -> u64 {241let mut hasher = std::hash::DefaultHasher::default();242self.inner.hash(&mut hasher);243hasher.finish()244}245}246247248