Path: blob/main/crates/polars-expr/src/expressions/sortby.rs
8422 views
use polars_core::POOL;1use polars_core::chunked_array::from_iterator_par::ChunkedCollectParIterExt;2use polars_core::prelude::*;3use polars_utils::idx_vec::IdxVec;4use rayon::prelude::*;56use super::*;7use crate::expressions::{8AggregationContext, PhysicalExpr, UpdateGroups, map_sorted_indices_to_group_idx,9map_sorted_indices_to_group_slice,10};1112pub struct SortByExpr {13pub(crate) input: Arc<dyn PhysicalExpr>,14pub(crate) by: Vec<Arc<dyn PhysicalExpr>>,15pub(crate) expr: Expr,16pub(crate) sort_options: SortMultipleOptions,17}1819impl SortByExpr {20pub fn new(21input: Arc<dyn PhysicalExpr>,22by: Vec<Arc<dyn PhysicalExpr>>,23expr: Expr,24sort_options: SortMultipleOptions,25) -> Self {26Self {27input,28by,29expr,30sort_options,31}32}33}3435fn prepare_bool_vec(values: &[bool], by_len: usize) -> Vec<bool> {36match (values.len(), by_len) {37// Equal length.38(n_rvalues, n) if n_rvalues == n => values.to_vec(),39// None given all false.40(0, n) => vec![false; n],41// Broadcast first.42(_, n) => vec![values[0]; n],43}44}4546static ERR_MSG: &str = "expressions in 'sort_by' must have matching group lengths";4748fn check_groups(a: &GroupsType, b: &GroupsType) -> PolarsResult<()> {49polars_ensure!(a.iter().zip(b.iter()).all(|(a, b)| {50a.len() == b.len()51}), ShapeMismatch: ERR_MSG);52Ok(())53}5455pub(super) fn update_groups_sort_by(56groups: &GroupsType,57sort_by_s: &Series,58options: &SortOptions,59) -> PolarsResult<GroupsType> {60// Will trigger a gather for every group, so rechunk before.61let sort_by_s = sort_by_s.rechunk();62let groups = POOL.install(|| {63groups64.par_iter()65.map(|indicator| sort_by_groups_single_by(indicator, &sort_by_s, options))66.collect::<PolarsResult<_>>()67})?;6869Ok(GroupsType::Idx(groups))70}7172fn sort_by_groups_single_by(73indicator: GroupsIndicator,74sort_by_s: &Series,75options: &SortOptions,76) -> PolarsResult<(IdxSize, IdxVec)> {77let options = SortOptions {78descending: options.descending,79nulls_last: options.nulls_last,80// We are already in par iter.81multithreaded: false,82..Default::default()83};84let new_idx = match indicator {85GroupsIndicator::Idx((_, idx)) => {86// SAFETY: group tuples are always in bounds.87let group = unsafe { sort_by_s.take_slice_unchecked(idx) };8889let sorted_idx = group.arg_sort(options);90map_sorted_indices_to_group_idx(&sorted_idx, idx)91},92GroupsIndicator::Slice([first, len]) => {93let group = sort_by_s.slice(first as i64, len as usize);94let sorted_idx = group.arg_sort(options);95map_sorted_indices_to_group_slice(&sorted_idx, first)96},97};9899let first = new_idx.first().unwrap_or(&0);100Ok((*first, new_idx))101}102103fn sort_by_groups_no_match_single<'a>(104mut ac_in: AggregationContext<'a>,105mut ac_by: AggregationContext<'a>,106descending: bool,107expr: &Expr,108) -> PolarsResult<AggregationContext<'a>> {109let s_in = ac_in.aggregated();110let s_by = ac_by.aggregated();111let mut s_in = s_in.list().unwrap().clone();112let mut s_by = s_by.list().unwrap().clone();113114let dtype = s_in.dtype().clone();115let ca: PolarsResult<ListChunked> = POOL.install(|| {116s_in.par_iter_indexed()117.zip(s_by.par_iter_indexed())118.map(|(opt_s, s_sort_by)| match (opt_s, s_sort_by) {119(Some(s), Some(s_sort_by)) => {120polars_ensure!(s.len() == s_sort_by.len(), ComputeError: "series lengths don't match in 'sort_by' expression");121let idx = s_sort_by.arg_sort(SortOptions {122descending,123// We are already in par iter.124multithreaded: false,125..Default::default()126});127Ok(Some(unsafe { s.take_unchecked(&idx) }))128},129_ => Ok(None),130})131.collect_ca_with_dtype(PlSmallStr::EMPTY, dtype)132});133let c = ca?.with_name(s_in.name().clone()).into_column();134ac_in.with_values(c, true, Some(expr))?;135Ok(ac_in)136}137138fn sort_by_groups_multiple_by(139indicator: GroupsIndicator,140sort_by_s: &[Series],141descending: &[bool],142nulls_last: &[bool],143multithreaded: bool,144maintain_order: bool,145) -> PolarsResult<(IdxSize, IdxVec)> {146let new_idx = match indicator {147GroupsIndicator::Idx((_first, idx)) => {148// SAFETY: group tuples are always in bounds.149let groups = sort_by_s150.iter()151.map(|s| unsafe { s.take_slice_unchecked(idx) })152.map(Column::from)153.collect::<Vec<_>>();154155let options = SortMultipleOptions {156descending: descending.to_owned(),157nulls_last: nulls_last.to_owned(),158multithreaded,159maintain_order,160limit: None,161};162163let sorted_idx = groups[0]164.as_materialized_series()165.arg_sort_multiple(&groups[1..], &options)166.unwrap();167map_sorted_indices_to_group_idx(&sorted_idx, idx)168},169GroupsIndicator::Slice([first, len]) => {170let groups = sort_by_s171.iter()172.map(|s| s.slice(first as i64, len as usize))173.map(Column::from)174.collect::<Vec<_>>();175176let options = SortMultipleOptions {177descending: descending.to_owned(),178nulls_last: nulls_last.to_owned(),179multithreaded,180maintain_order,181limit: None,182};183let sorted_idx = groups[0]184.as_materialized_series()185.arg_sort_multiple(&groups[1..], &options)186.unwrap();187map_sorted_indices_to_group_slice(&sorted_idx, first)188},189};190let first = new_idx191.first()192.ok_or_else(|| polars_err!(ComputeError: "{ERR_MSG}"))?;193194Ok((*first, new_idx))195}196197impl PhysicalExpr for SortByExpr {198fn as_expression(&self) -> Option<&Expr> {199Some(&self.expr)200}201202fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {203let series_f = || self.input.evaluate(df, state);204if self.by.is_empty() {205// Sorting by 0 columns returns input unchanged.206return series_f();207}208let (series, sorted_idx) = if self.by.len() == 1 {209let sorted_idx_f = || {210let s_sort_by = self.by[0].evaluate(df, state)?;211Ok(s_sort_by.arg_sort(SortOptions::from(&self.sort_options)))212};213POOL.install(|| rayon::join(series_f, sorted_idx_f))214} else {215let descending = prepare_bool_vec(&self.sort_options.descending, self.by.len());216let nulls_last = prepare_bool_vec(&self.sort_options.nulls_last, self.by.len());217218let sorted_idx_f = || {219let mut needs_broadcast = false;220let mut broadcast_length = 1;221222let mut s_sort_by = self223.by224.iter()225.enumerate()226.map(|(i, e)| {227let column = e.evaluate(df, state).map(|c| match c.dtype() {228#[cfg(feature = "dtype-categorical")]229DataType::Categorical(_, _) | DataType::Enum(_, _) => c,230_ => c.to_physical_repr(),231})?;232233if column.len() == 1 && broadcast_length != 1 {234polars_ensure!(235e.is_scalar(),236ShapeMismatch: "non-scalar expression produces broadcasting column",237);238239return Ok(column.new_from_index(0, broadcast_length));240}241242if broadcast_length != column.len() {243polars_ensure!(244broadcast_length == 1, ShapeMismatch:245"`sort_by` produced different length ({}) than earlier Series' length in `by` ({})",246broadcast_length, column.len()247);248249needs_broadcast |= i > 0;250broadcast_length = column.len();251}252253Ok(column)254})255.collect::<PolarsResult<Vec<_>>>()?;256257if needs_broadcast {258for c in s_sort_by.iter_mut() {259if c.len() != broadcast_length {260*c = c.new_from_index(0, broadcast_length);261}262}263}264265let options = self266.sort_options267.clone()268.with_order_descending_multi(descending)269.with_nulls_last_multi(nulls_last);270271s_sort_by[0]272.as_materialized_series()273.arg_sort_multiple(&s_sort_by[1..], &options)274};275POOL.install(|| rayon::join(series_f, sorted_idx_f))276};277let (sorted_idx, series) = (sorted_idx?, series?);278polars_ensure!(279sorted_idx.len() == series.len(),280expr = self.expr, ShapeMismatch:281"`sort_by` produced different length ({}) than the Series that has to be sorted ({})",282sorted_idx.len(), series.len()283);284285// SAFETY: sorted index are within bounds.286unsafe { Ok(series.take_unchecked(&sorted_idx)) }287}288289#[allow(clippy::ptr_arg)]290fn evaluate_on_groups<'a>(291&self,292df: &DataFrame,293groups: &'a GroupPositions,294state: &ExecutionState,295) -> PolarsResult<AggregationContext<'a>> {296let mut ac_in = self.input.evaluate_on_groups(df, groups, state)?;297let descending = prepare_bool_vec(&self.sort_options.descending, self.by.len());298let nulls_last = prepare_bool_vec(&self.sort_options.nulls_last, self.by.len());299300let mut ac_sort_by = self301.by302.iter()303.map(|e| e.evaluate_on_groups(df, groups, state))304.collect::<PolarsResult<Vec<_>>>()?;305306assert!(307ac_sort_by308.iter()309.all(|ac_sort_by| ac_sort_by.groups.len() == ac_in.groups.len())310);311312// Enable reliable length checks downstream313ac_in.set_groups_for_undefined_agg_states();314ac_sort_by315.iter_mut()316.for_each(|ac| ac.set_groups_for_undefined_agg_states());317318// If every input is a LiteralScalar, we return a LiteralScalar.319// Otherwise, we convert any LiteralScalar to AggregatedList.320let all_literal = matches!(ac_in.state, AggState::LiteralScalar(_))321|| ac_sort_by322.iter()323.all(|ac| matches!(ac.state, AggState::LiteralScalar(_)));324325if all_literal {326return Ok(ac_in);327} else {328if matches!(ac_in.state, AggState::LiteralScalar(_)) {329ac_in.aggregated();330}331for ac in ac_sort_by.iter_mut() {332if matches!(ac.state, AggState::LiteralScalar(_)) {333ac.aggregated();334}335}336}337338let mut sort_by_s = ac_sort_by339.iter()340.map(|c| {341let c = c.flat_naive();342match c.dtype() {343#[cfg(feature = "dtype-categorical")]344DataType::Categorical(_, _) | DataType::Enum(_, _) => {345c.as_materialized_series().clone()346},347// @scalar-opt348// @partition-opt349_ => c.to_physical_repr().take_materialized_series(),350}351})352.collect::<Vec<_>>();353354let ordered_by_group_operation = matches!(355ac_sort_by[0].update_groups,356UpdateGroups::WithSeriesLen | UpdateGroups::WithGroupsLen357);358359let groups = if self.by.len() == 1 {360let mut ac_sort_by = ac_sort_by.pop().unwrap();361362// The groups of the lhs of the expressions do not match the series values,363// we must take the slower path.364if !matches!(ac_in.update_groups, UpdateGroups::No) {365return sort_by_groups_no_match_single(366ac_in,367ac_sort_by,368self.sort_options.descending[0],369&self.expr,370);371};372373let sort_by_s = sort_by_s.pop().unwrap();374let groups = ac_sort_by.groups();375376let (check, groups) = POOL.join(377|| check_groups(groups, ac_in.groups()),378|| {379update_groups_sort_by(380groups,381&sort_by_s,382&SortOptions {383descending: descending[0],384nulls_last: nulls_last[0],385..Default::default()386},387)388},389);390check?;391392groups?393} else {394let groups = ac_sort_by[0].groups();395396let groups = POOL.install(|| {397groups398.par_iter()399.map(|indicator| {400sort_by_groups_multiple_by(401indicator,402&sort_by_s,403&descending,404&nulls_last,405self.sort_options.multithreaded,406self.sort_options.maintain_order,407)408})409.collect::<PolarsResult<_>>()410});411GroupsType::Idx(groups?)412};413414// If the rhs is already aggregated once, it is reordered by the415// group_by operation - we must ensure that we are as well.416if ordered_by_group_operation {417let s = ac_in.aggregated();418ac_in.with_values(419s.explode(ExplodeOptions {420empty_as_null: true,421keep_nulls: true,422})423.unwrap(),424false,425None,426)?;427}428429ac_in.with_groups(groups.into_sliceable());430Ok(ac_in)431}432433fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {434self.input.to_field(input_schema)435}436437fn is_scalar(&self) -> bool {438false439}440}441442443