Path: blob/main/crates/polars-expr/src/expressions/slice.rs
use AnyValue::Null;
use polars_core::POOL;
use polars_core::prelude::*;
use polars_core::utils::{CustomIterTools, slice_offsets};
use polars_utils::idx_vec::IdxVec;
use rayon::prelude::*;

use super::*;
use crate::expressions::{AggregationContext, PhysicalExpr};

pub struct SliceExpr {
    pub(crate) input: Arc<dyn PhysicalExpr>,
    pub(crate) offset: Arc<dyn PhysicalExpr>,
    pub(crate) length: Arc<dyn PhysicalExpr>,
    pub(crate) expr: Expr,
}

fn extract_offset(offset: &Column, expr: &Expr) -> PolarsResult<i64> {
    polars_ensure!(
        offset.len() <= 1, expr = expr, ComputeError:
        "invalid argument to slice; expected an offset literal, got series of length {}",
        offset.len()
    );
    offset.get(0).unwrap().extract().ok_or_else(
        || polars_err!(expr = expr, ComputeError: "unable to extract offset from {:?}", offset),
    )
}

fn extract_length(length: &Column, expr: &Expr) -> PolarsResult<usize> {
    polars_ensure!(
        length.len() <= 1, expr = expr, ComputeError:
        "invalid argument to slice; expected a length literal, got series of length {}",
        length.len()
    );
    match length.get(0).unwrap() {
        Null => Ok(usize::MAX),
        v => v.extract().ok_or_else(
            || polars_err!(expr = expr, ComputeError: "unable to extract length from {:?}", length),
        ),
    }
}

fn extract_args(offset: &Column, length: &Column, expr: &Expr) -> PolarsResult<(i64, usize)> {
    Ok((extract_offset(offset, expr)?, extract_length(length, expr)?))
}

fn check_argument(arg: &Column, groups: &GroupsType, name: &str, expr: &Expr) -> PolarsResult<()> {
    polars_ensure!(
        !matches!(arg.dtype(), DataType::List(_)), expr = expr, ComputeError:
        "invalid slice argument: cannot use an array as {} argument", name,
    );
    polars_ensure!(
        arg.len() == groups.len(), expr = expr, ComputeError:
        "invalid slice argument: the evaluated length expression was \
        of different {} than the number of groups", name
    );
    polars_ensure!(
        arg.null_count() == 0, expr = expr, ComputeError:
        "invalid slice argument: the {} expression has nulls", name
    );
    Ok(())
}

fn slice_groups_idx(offset: i64, length: usize, mut first: IdxSize, idx: &[IdxSize]) -> IdxItem {
    let (offset, len) = slice_offsets(offset, length, idx.len());

    // If slice isn't out of bounds, we replace first.
    // If slice is oob, the `idx` vec will be empty and `first` will be ignored
    if let Some(f) = idx.get(offset) {
        first = *f;
    }
    // This is a clone of the vec, which is unfortunate. Maybe we have a `sliceable` unitvec one day.
    (first, IdxVec::from_slice(&idx[offset..offset + len]))
}

fn slice_groups_slice(offset: i64, length: usize, first: IdxSize, len: IdxSize) -> [IdxSize; 2] {
    let (offset, len) = slice_offsets(offset, length, len as usize);
    [first + offset as IdxSize, len as IdxSize]
}

impl PhysicalExpr for SliceExpr {
    fn as_expression(&self) -> Option<&Expr> {
        Some(&self.expr)
    }

    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        let results = POOL.install(|| {
            [&self.offset, &self.length, &self.input]
                .par_iter()
                .map(|e| e.evaluate(df, state))
                .collect::<PolarsResult<Vec<_>>>()
        })?;
        let offset = &results[0];
        let length = &results[1];
        let series = &results[2];
        let (offset, length) = extract_args(offset, length, &self.expr)?;

        Ok(series.slice(offset, length))
    }

    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        let mut results = POOL.install(|| {
            [&self.offset, &self.length, &self.input]
                .par_iter()
                .map(|e| e.evaluate_on_groups(df, groups, state))
                .collect::<PolarsResult<Vec<_>>>()
        })?;
        let mut ac = results.pop().unwrap();

        let mut ac_length = results.pop().unwrap();
        let mut ac_offset = results.pop().unwrap();

        // Fast path:
        // When `input` (ac) is a LiteralValue, and both `offset` and `length` are LiteralScalar,
        // we slice the LiteralValue and avoid calling groups().
        // TODO: When `input` (ac) is a LiteralValue, and `offset` or `length` is not a LiteralScalar,
        // we can simplify the groups calculation since we have a List containing one scalar for
        // each group.

        use AggState::*;
        let groups = match (&ac_offset.state, &ac_length.state) {
            (LiteralScalar(offset), LiteralScalar(length)) => {
                let (offset, length) = extract_args(offset, length, &self.expr)?;

                if let LiteralScalar(s) = ac.agg_state() {
                    let s1 = s.slice(offset, length);
                    ac.with_literal(s1);
                    ac.aggregated();
                    return Ok(ac);
                }
                if let AggregatedScalar(c) = ac.state {
                    ac.state = AggregatedList(c.as_list().into_column());
                    ac.update_groups = UpdateGroups::WithSeriesLen;
                }
                let groups = ac.groups();

                match groups.as_ref().as_ref() {
                    GroupsType::Idx(groups) => {
                        let groups = groups
                            .iter()
                            .map(|(first, idx)| slice_groups_idx(offset, length, first, idx))
                            .collect();
                        GroupsType::Idx(groups)
                    },
                    GroupsType::Slice {
                        groups,
                        overlapping,
                        monotonic,
                    } => {
                        let groups = groups
                            .iter()
                            .map(|&[first, len]| slice_groups_slice(offset, length, first, len))
                            .collect_trusted();
                        GroupsType::new_slice(groups, *overlapping, *monotonic)
                    },
                }
            },
            (LiteralScalar(offset), _) => {
                if matches!(ac.state, LiteralScalar(_)) {
                    ac.aggregated();
                } else if let AggregatedScalar(c) = ac.state {
                    ac.state = AggregatedList(c.as_list().into_column());
                    ac.update_groups = UpdateGroups::WithSeriesLen;
                }
                let groups = ac.groups();
                let offset = extract_offset(offset, &self.expr)?;
                let length = ac_length.aggregated();
                check_argument(&length, groups, "length", &self.expr)?;

                let length = length.cast(&IDX_DTYPE)?;
                let length = length.idx().unwrap();

                match groups.as_ref().as_ref() {
                    GroupsType::Idx(groups) => {
                        let groups = groups
                            .iter()
                            .zip(length.into_no_null_iter())
                            .map(|((first, idx), length)| {
                                slice_groups_idx(offset, length as usize, first, idx)
                            })
                            .collect();
                        GroupsType::Idx(groups)
                    },
                    GroupsType::Slice {
                        groups,
                        overlapping,
                        monotonic: _,
                    } => {
                        let groups = groups
                            .iter()
                            .zip(length.into_no_null_iter())
                            .map(|(&[first, len], length)| {
                                slice_groups_slice(offset, length as usize, first, len)
                            })
                            .collect_trusted();
                        GroupsType::new_slice(groups, *overlapping, false)
                    },
                }
            },
            (_, LiteralScalar(length)) => {
                if matches!(ac.state, LiteralScalar(_)) {
                    ac.aggregated();
                } else if let AggregatedScalar(c) = ac.state {
                    ac.state = AggregatedList(c.as_list().into_column());
                    ac.update_groups = UpdateGroups::WithSeriesLen;
                }
                let groups = ac.groups();
                let length = extract_length(length, &self.expr)?;
                let offset = ac_offset.aggregated();
                check_argument(&offset, groups, "offset", &self.expr)?;

                let offset = offset.cast(&DataType::Int64)?;
                let offset = offset.i64().unwrap();

                match groups.as_ref().as_ref() {
                    GroupsType::Idx(groups) => {
                        let groups = groups
                            .iter()
                            .zip(offset.into_no_null_iter())
                            .map(|((first, idx), offset)| {
                                slice_groups_idx(offset, length, first, idx)
                            })
                            .collect();
                        GroupsType::Idx(groups)
                    },
                    GroupsType::Slice {
                        groups,
                        overlapping,
                        monotonic: _,
                    } => {
                        let groups = groups
                            .iter()
                            .zip(offset.into_no_null_iter())
                            .map(|(&[first, len], offset)| {
                                slice_groups_slice(offset, length, first, len)
                            })
                            .collect_trusted();
                        GroupsType::new_slice(groups, *overlapping, false)
                    },
                }
            },
            _ => {
                if matches!(ac.state, LiteralScalar(_)) {
                    ac.aggregated();
                } else if let AggregatedScalar(c) = ac.state {
                    ac.state = AggregatedList(c.as_list().into_column());
                    ac.update_groups = UpdateGroups::WithSeriesLen;
                }

                let groups = ac.groups();
                let length = ac_length.aggregated();
                let offset = ac_offset.aggregated();
                check_argument(&length, groups, "length", &self.expr)?;
                check_argument(&offset, groups, "offset", &self.expr)?;

                let offset = offset.cast(&DataType::Int64)?;
                let offset = offset.i64().unwrap();

                let length = length.cast(&IDX_DTYPE)?;
                let length = length.idx().unwrap();

                match groups.as_ref().as_ref() {
                    GroupsType::Idx(groups) => {
                        let groups = groups
                            .iter()
                            .zip(offset.into_no_null_iter())
                            .zip(length.into_no_null_iter())
                            .map(|(((first, idx), offset), length)| {
                                slice_groups_idx(offset, length as usize, first, idx)
                            })
                            .collect();
                        GroupsType::Idx(groups)
                    },
                    GroupsType::Slice {
                        groups,
                        overlapping,
                        monotonic: _,
                    } => {
                        let groups = groups
                            .iter()
                            .zip(offset.into_no_null_iter())
                            .zip(length.into_no_null_iter())
                            .map(|((&[first, len], offset), length)| {
                                slice_groups_slice(offset, length as usize, first, len)
                            })
                            .collect_trusted();
                        GroupsType::new_slice(groups, *overlapping, false)
                    },
                }
            },
        };

        ac.with_groups(groups.into_sliceable())
            .set_original_len(false);

        Ok(ac)
    }

    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
        self.input.to_field(input_schema)
    }

    fn is_scalar(&self) -> bool {
        false
    }
}
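
// Illustrative sketch (not from the upstream file): a minimal test of the
// group-boundary arithmetic in `slice_groups_slice`, assuming that
// `slice_offsets` resolves the offset within the group and clamps the
// requested length to the group length. The module name and the concrete
// row numbers below are hypothetical.
#[cfg(test)]
mod slice_groups_sketch {
    use super::*;

    #[test]
    fn slice_stays_within_group_bounds() {
        // A slice-encoded group starting at global row 10 and spanning 5 rows.
        let (first, len): (IdxSize, IdxSize) = (10, 5);

        // Taking `length = 2` rows from `offset = 1` shifts the group start by
        // the offset and keeps the requested length.
        assert_eq!(slice_groups_slice(1, 2, first, len), [11, 2]);

        // A `length` of `usize::MAX` (the sentinel `extract_length` returns for
        // a null length literal) is clamped to the group length.
        assert_eq!(slice_groups_slice(0, usize::MAX, first, len), [10, 5]);
    }
}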