// Path: blob/main/crates/polars-expr/src/expressions/slice.rs
// 6940 views
use AnyValue::Null;1use polars_core::POOL;2use polars_core::prelude::*;3use polars_core::utils::{CustomIterTools, slice_offsets};4use rayon::prelude::*;56use super::*;7use crate::expressions::{AggregationContext, PhysicalExpr};89pub struct SliceExpr {10pub(crate) input: Arc<dyn PhysicalExpr>,11pub(crate) offset: Arc<dyn PhysicalExpr>,12pub(crate) length: Arc<dyn PhysicalExpr>,13pub(crate) expr: Expr,14}1516fn extract_offset(offset: &Column, expr: &Expr) -> PolarsResult<i64> {17polars_ensure!(18offset.len() <= 1, expr = expr, ComputeError:19"invalid argument to slice; expected an offset literal, got series of length {}",20offset.len()21);22offset.get(0).unwrap().extract().ok_or_else(23|| polars_err!(expr = expr, ComputeError: "unable to extract offset from {:?}", offset),24)25}2627fn extract_length(length: &Column, expr: &Expr) -> PolarsResult<usize> {28polars_ensure!(29length.len() <= 1, expr = expr, ComputeError:30"invalid argument to slice; expected a length literal, got series of length {}",31length.len()32);33match length.get(0).unwrap() {34Null => Ok(usize::MAX),35v => v.extract().ok_or_else(36|| polars_err!(expr = expr, ComputeError: "unable to extract length from {:?}", length),37),38}39}4041fn extract_args(offset: &Column, length: &Column, expr: &Expr) -> PolarsResult<(i64, usize)> {42Ok((extract_offset(offset, expr)?, extract_length(length, expr)?))43}4445fn check_argument(arg: &Column, groups: &GroupsType, name: &str, expr: &Expr) -> PolarsResult<()> {46polars_ensure!(47!matches!(arg.dtype(), DataType::List(_)), expr = expr, ComputeError:48"invalid slice argument: cannot use an array as {} argument", name,49);50polars_ensure!(51arg.len() == groups.len(), expr = expr, ComputeError:52"invalid slice argument: the evaluated length expression was \53of different {} than the number of groups", name54);55polars_ensure!(56arg.null_count() == 0, expr = expr, ComputeError:57"invalid slice argument: the {} expression has nulls", name58);59Ok(())60}6162fn 
slice_groups_idx(offset: i64, length: usize, mut first: IdxSize, idx: &[IdxSize]) -> IdxItem {63let (offset, len) = slice_offsets(offset, length, idx.len());6465// If slice isn't out of bounds, we replace first.66// If slice is oob, the `idx` vec will be empty and `first` will be ignored67if let Some(f) = idx.get(offset) {68first = *f;69}70// This is a clone of the vec, which is unfortunate. Maybe we have a `sliceable` unitvec one day.71(first, idx[offset..offset + len].into())72}7374fn slice_groups_slice(offset: i64, length: usize, first: IdxSize, len: IdxSize) -> [IdxSize; 2] {75let (offset, len) = slice_offsets(offset, length, len as usize);76[first + offset as IdxSize, len as IdxSize]77}7879impl PhysicalExpr for SliceExpr {80fn as_expression(&self) -> Option<&Expr> {81Some(&self.expr)82}8384fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {85let results = POOL.install(|| {86[&self.offset, &self.length, &self.input]87.par_iter()88.map(|e| e.evaluate(df, state))89.collect::<PolarsResult<Vec<_>>>()90})?;91let offset = &results[0];92let length = &results[1];93let series = &results[2];94let (offset, length) = extract_args(offset, length, &self.expr)?;9596Ok(series.slice(offset, length))97}9899fn evaluate_on_groups<'a>(100&self,101df: &DataFrame,102groups: &'a GroupPositions,103state: &ExecutionState,104) -> PolarsResult<AggregationContext<'a>> {105let mut results = POOL.install(|| {106[&self.offset, &self.length, &self.input]107.par_iter()108.map(|e| e.evaluate_on_groups(df, groups, state))109.collect::<PolarsResult<Vec<_>>>()110})?;111let mut ac = results.pop().unwrap();112113if let AggState::AggregatedScalar(_) = ac.agg_state() {114polars_bail!(InvalidOperation: "cannot slice() an aggregated scalar value")115}116117let mut ac_length = results.pop().unwrap();118let mut ac_offset = results.pop().unwrap();119120// Fast path:121// When `input` (ac) is a LiteralValue, and both `offset` and `length` are LiteralScalar,122// we slice the 
LiteralValue and avoid calling groups().123// TODO: When `input` (ac) is a LiteralValue, and `offset` or `length` is not a LiteralScalar,124// we can simplify the groups calculation since we have a List containing one scalar for125// each group.126127use AggState::*;128let groups = match (&ac_offset.state, &ac_length.state) {129(LiteralScalar(offset), LiteralScalar(length)) => {130let (offset, length) = extract_args(offset, length, &self.expr)?;131132if let LiteralScalar(s) = ac.agg_state() {133let s1 = s.slice(offset, length);134ac.with_literal(s1);135ac.aggregated();136return Ok(ac);137}138let groups = ac.groups();139140match groups.as_ref().as_ref() {141GroupsType::Idx(groups) => {142let groups = groups143.iter()144.map(|(first, idx)| slice_groups_idx(offset, length, first, idx))145.collect();146GroupsType::Idx(groups)147},148GroupsType::Slice { groups, .. } => {149let groups = groups150.iter()151.map(|&[first, len]| slice_groups_slice(offset, length, first, len))152.collect_trusted();153GroupsType::Slice {154groups,155rolling: false,156}157},158}159},160(LiteralScalar(offset), _) => {161if matches!(ac.state, LiteralScalar(_)) {162ac.aggregated();163}164let groups = ac.groups();165let offset = extract_offset(offset, &self.expr)?;166let length = ac_length.aggregated();167check_argument(&length, groups, "length", &self.expr)?;168169let length = length.cast(&IDX_DTYPE)?;170let length = length.idx().unwrap();171172match groups.as_ref().as_ref() {173GroupsType::Idx(groups) => {174let groups = groups175.iter()176.zip(length.into_no_null_iter())177.map(|((first, idx), length)| {178slice_groups_idx(offset, length as usize, first, idx)179})180.collect();181GroupsType::Idx(groups)182},183GroupsType::Slice { groups, .. 
} => {184let groups = groups185.iter()186.zip(length.into_no_null_iter())187.map(|(&[first, len], length)| {188slice_groups_slice(offset, length as usize, first, len)189})190.collect_trusted();191GroupsType::Slice {192groups,193rolling: false,194}195},196}197},198(_, LiteralScalar(length)) => {199if matches!(ac.state, LiteralScalar(_)) {200ac.aggregated();201}202let groups = ac.groups();203let length = extract_length(length, &self.expr)?;204let offset = ac_offset.aggregated();205check_argument(&offset, groups, "offset", &self.expr)?;206207let offset = offset.cast(&DataType::Int64)?;208let offset = offset.i64().unwrap();209210match groups.as_ref().as_ref() {211GroupsType::Idx(groups) => {212let groups = groups213.iter()214.zip(offset.into_no_null_iter())215.map(|((first, idx), offset)| {216slice_groups_idx(offset, length, first, idx)217})218.collect();219GroupsType::Idx(groups)220},221GroupsType::Slice { groups, .. } => {222let groups = groups223.iter()224.zip(offset.into_no_null_iter())225.map(|(&[first, len], offset)| {226slice_groups_slice(offset, length, first, len)227})228.collect_trusted();229GroupsType::Slice {230groups,231rolling: false,232}233},234}235},236_ => {237if matches!(ac.state, LiteralScalar(_)) {238ac.aggregated();239}240241let groups = ac.groups();242let length = ac_length.aggregated();243let offset = ac_offset.aggregated();244check_argument(&length, groups, "length", &self.expr)?;245check_argument(&offset, groups, "offset", &self.expr)?;246247let offset = offset.cast(&DataType::Int64)?;248let offset = offset.i64().unwrap();249250let length = length.cast(&IDX_DTYPE)?;251let length = length.idx().unwrap();252253match groups.as_ref().as_ref() {254GroupsType::Idx(groups) => {255let groups = groups256.iter()257.zip(offset.into_no_null_iter())258.zip(length.into_no_null_iter())259.map(|(((first, idx), offset), length)| {260slice_groups_idx(offset, length as usize, first, idx)261})262.collect();263GroupsType::Idx(groups)264},265GroupsType::Slice { 
groups, .. } => {266let groups = groups267.iter()268.zip(offset.into_no_null_iter())269.zip(length.into_no_null_iter())270.map(|((&[first, len], offset), length)| {271slice_groups_slice(offset, length as usize, first, len)272})273.collect_trusted();274GroupsType::Slice {275groups,276rolling: false,277}278},279}280},281};282283ac.with_groups(groups.into_sliceable())284.set_original_len(false);285286Ok(ac)287}288289fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {290self.input.to_field(input_schema)291}292293fn is_scalar(&self) -> bool {294false295}296}297298299