Path: blob/main/crates/polars-expr/src/expressions/sortby.rs
6940 views
use polars_core::POOL;1use polars_core::chunked_array::from_iterator_par::ChunkedCollectParIterExt;2use polars_core::prelude::*;3use polars_utils::idx_vec::IdxVec;4use rayon::prelude::*;56use super::*;7use crate::expressions::{8AggregationContext, PhysicalExpr, UpdateGroups, map_sorted_indices_to_group_idx,9map_sorted_indices_to_group_slice,10};1112pub struct SortByExpr {13pub(crate) input: Arc<dyn PhysicalExpr>,14pub(crate) by: Vec<Arc<dyn PhysicalExpr>>,15pub(crate) expr: Expr,16pub(crate) sort_options: SortMultipleOptions,17}1819impl SortByExpr {20pub fn new(21input: Arc<dyn PhysicalExpr>,22by: Vec<Arc<dyn PhysicalExpr>>,23expr: Expr,24sort_options: SortMultipleOptions,25) -> Self {26Self {27input,28by,29expr,30sort_options,31}32}33}3435fn prepare_bool_vec(values: &[bool], by_len: usize) -> Vec<bool> {36match (values.len(), by_len) {37// Equal length.38(n_rvalues, n) if n_rvalues == n => values.to_vec(),39// None given all false.40(0, n) => vec![false; n],41// Broadcast first.42(_, n) => vec![values[0]; n],43}44}4546static ERR_MSG: &str = "expressions in 'sort_by' must have matching group lengths";4748fn check_groups(a: &GroupsType, b: &GroupsType) -> PolarsResult<()> {49polars_ensure!(a.iter().zip(b.iter()).all(|(a, b)| {50a.len() == b.len()51}), ComputeError: ERR_MSG);52Ok(())53}5455pub(super) fn update_groups_sort_by(56groups: &GroupsType,57sort_by_s: &Series,58options: &SortOptions,59) -> PolarsResult<GroupsType> {60// Will trigger a gather for every group, so rechunk before.61let sort_by_s = sort_by_s.rechunk();62let groups = POOL.install(|| {63groups64.par_iter()65.map(|indicator| sort_by_groups_single_by(indicator, &sort_by_s, options))66.collect::<PolarsResult<_>>()67})?;6869Ok(GroupsType::Idx(groups))70}7172fn sort_by_groups_single_by(73indicator: GroupsIndicator,74sort_by_s: &Series,75options: &SortOptions,76) -> PolarsResult<(IdxSize, IdxVec)> {77let options = SortOptions {78descending: options.descending,79nulls_last: options.nulls_last,80// We are already in par iter.81multithreaded: false,82..Default::default()83};84let new_idx = match indicator {85GroupsIndicator::Idx((_, idx)) => {86// SAFETY: group tuples are always in bounds.87let group = unsafe { sort_by_s.take_slice_unchecked(idx) };8889let sorted_idx = group.arg_sort(options);90map_sorted_indices_to_group_idx(&sorted_idx, idx)91},92GroupsIndicator::Slice([first, len]) => {93let group = sort_by_s.slice(first as i64, len as usize);94let sorted_idx = group.arg_sort(options);95map_sorted_indices_to_group_slice(&sorted_idx, first)96},97};98let first = new_idx99.first()100.ok_or_else(|| polars_err!(ComputeError: "{}", ERR_MSG))?;101102Ok((*first, new_idx))103}104105fn sort_by_groups_no_match_single<'a>(106mut ac_in: AggregationContext<'a>,107mut ac_by: AggregationContext<'a>,108descending: bool,109expr: &Expr,110) -> PolarsResult<AggregationContext<'a>> {111let s_in = ac_in.aggregated();112let s_by = ac_by.aggregated();113let mut s_in = s_in.list().unwrap().clone();114let mut s_by = s_by.list().unwrap().clone();115116let dtype = s_in.dtype().clone();117let ca: PolarsResult<ListChunked> = POOL.install(|| {118s_in.par_iter_indexed()119.zip(s_by.par_iter_indexed())120.map(|(opt_s, s_sort_by)| match (opt_s, s_sort_by) {121(Some(s), Some(s_sort_by)) => {122polars_ensure!(s.len() == s_sort_by.len(), ComputeError: "series lengths don't match in 'sort_by' expression");123let idx = s_sort_by.arg_sort(SortOptions {124descending,125// We are already in par iter.126multithreaded: false,127..Default::default()128});129Ok(Some(unsafe { s.take_unchecked(&idx) }))130},131_ => Ok(None),132})133.collect_ca_with_dtype(PlSmallStr::EMPTY, dtype)134});135let c = ca?.with_name(s_in.name().clone()).into_column();136ac_in.with_values(c, true, Some(expr))?;137Ok(ac_in)138}139140fn sort_by_groups_multiple_by(141indicator: GroupsIndicator,142sort_by_s: &[Series],143descending: &[bool],144nulls_last: &[bool],145multithreaded: bool,146maintain_order: bool,147) -> PolarsResult<(IdxSize, IdxVec)> {148let new_idx = match indicator {149GroupsIndicator::Idx((_first, idx)) => {150// SAFETY: group tuples are always in bounds.151let groups = sort_by_s152.iter()153.map(|s| unsafe { s.take_slice_unchecked(idx) })154.map(Column::from)155.collect::<Vec<_>>();156157let options = SortMultipleOptions {158descending: descending.to_owned(),159nulls_last: nulls_last.to_owned(),160multithreaded,161maintain_order,162limit: None,163};164165let sorted_idx = groups[0]166.as_materialized_series()167.arg_sort_multiple(&groups[1..], &options)168.unwrap();169map_sorted_indices_to_group_idx(&sorted_idx, idx)170},171GroupsIndicator::Slice([first, len]) => {172let groups = sort_by_s173.iter()174.map(|s| s.slice(first as i64, len as usize))175.map(Column::from)176.collect::<Vec<_>>();177178let options = SortMultipleOptions {179descending: descending.to_owned(),180nulls_last: nulls_last.to_owned(),181multithreaded,182maintain_order,183limit: None,184};185let sorted_idx = groups[0]186.as_materialized_series()187.arg_sort_multiple(&groups[1..], &options)188.unwrap();189map_sorted_indices_to_group_slice(&sorted_idx, first)190},191};192let first = new_idx193.first()194.ok_or_else(|| polars_err!(ComputeError: "{}", ERR_MSG))?;195196Ok((*first, new_idx))197}198199impl PhysicalExpr for SortByExpr {200fn as_expression(&self) -> Option<&Expr> {201Some(&self.expr)202}203204fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {205let series_f = || self.input.evaluate(df, state);206if self.by.is_empty() {207// Sorting by 0 columns returns input unchanged.208return series_f();209}210let (series, sorted_idx) = if self.by.len() == 1 {211let sorted_idx_f = || {212let s_sort_by = self.by[0].evaluate(df, state)?;213Ok(s_sort_by.arg_sort(SortOptions::from(&self.sort_options)))214};215POOL.install(|| rayon::join(series_f, sorted_idx_f))216} else {217let descending = prepare_bool_vec(&self.sort_options.descending, self.by.len());218let nulls_last = prepare_bool_vec(&self.sort_options.nulls_last, self.by.len());219220let sorted_idx_f = || {221let mut needs_broadcast = false;222let mut broadcast_length = 1;223224let mut s_sort_by = self225.by226.iter()227.enumerate()228.map(|(i, e)| {229let column = e.evaluate(df, state).map(|c| match c.dtype() {230#[cfg(feature = "dtype-categorical")]231DataType::Categorical(_, _) | DataType::Enum(_, _) => c,232_ => c.to_physical_repr(),233})?;234235if column.len() == 1 && broadcast_length != 1 {236polars_ensure!(237e.is_scalar(),238ShapeMismatch: "non-scalar expression produces broadcasting column",239);240241return Ok(column.new_from_index(0, broadcast_length));242}243244if broadcast_length != column.len() {245polars_ensure!(246broadcast_length == 1, ShapeMismatch:247"`sort_by` produced different length ({}) than earlier Series' length in `by` ({})",248broadcast_length, column.len()249);250251needs_broadcast |= i > 0;252broadcast_length = column.len();253}254255Ok(column)256})257.collect::<PolarsResult<Vec<_>>>()?;258259if needs_broadcast {260for c in s_sort_by.iter_mut() {261if c.len() != broadcast_length {262*c = c.new_from_index(0, broadcast_length);263}264}265}266267let options = self268.sort_options269.clone()270.with_order_descending_multi(descending)271.with_nulls_last_multi(nulls_last);272273s_sort_by[0]274.as_materialized_series()275.arg_sort_multiple(&s_sort_by[1..], &options)276};277POOL.install(|| rayon::join(series_f, sorted_idx_f))278};279let (sorted_idx, series) = (sorted_idx?, series?);280polars_ensure!(281sorted_idx.len() == series.len(),282expr = self.expr, ShapeMismatch:283"`sort_by` produced different length ({}) than the Series that has to be sorted ({})",284sorted_idx.len(), series.len()285);286287// SAFETY: sorted index are within bounds.288unsafe { Ok(series.take_unchecked(&sorted_idx)) }289}290291#[allow(clippy::ptr_arg)]292fn evaluate_on_groups<'a>(293&self,294df: &DataFrame,295groups: &'a GroupPositions,296state: &ExecutionState,297) -> PolarsResult<AggregationContext<'a>> {298let mut ac_in = self.input.evaluate_on_groups(df, groups, state)?;299let descending = prepare_bool_vec(&self.sort_options.descending, self.by.len());300let nulls_last = prepare_bool_vec(&self.sort_options.nulls_last, self.by.len());301302let mut ac_sort_by = self303.by304.iter()305.map(|e| e.evaluate_on_groups(df, groups, state))306.collect::<PolarsResult<Vec<_>>>()?;307308assert!(309ac_sort_by310.iter()311.all(|ac_sort_by| ac_sort_by.groups.len() == ac_in.groups.len())312);313314// If every input is a LiteralScalar, we return a LiteralScalar.315// Otherwise, we convert any LiteralScalar to AggregatedList.316let all_literal = matches!(ac_in.state, AggState::LiteralScalar(_))317|| ac_sort_by318.iter()319.all(|ac| matches!(ac.state, AggState::LiteralScalar(_)));320321if all_literal {322return Ok(ac_in);323} else {324if matches!(ac_in.state, AggState::LiteralScalar(_)) {325ac_in.aggregated();326}327for ac in ac_sort_by.iter_mut() {328if matches!(ac.state, AggState::LiteralScalar(_)) {329ac.aggregated();330}331}332}333334let mut sort_by_s = ac_sort_by335.iter()336.map(|c| {337let c = c.flat_naive();338match c.dtype() {339#[cfg(feature = "dtype-categorical")]340DataType::Categorical(_, _) | DataType::Enum(_, _) => {341c.as_materialized_series().clone()342},343// @scalar-opt344// @partition-opt345_ => c.to_physical_repr().take_materialized_series(),346}347})348.collect::<Vec<_>>();349350let ordered_by_group_operation = matches!(351ac_sort_by[0].update_groups,352UpdateGroups::WithSeriesLen | UpdateGroups::WithGroupsLen353);354355let groups = if self.by.len() == 1 {356let mut ac_sort_by = ac_sort_by.pop().unwrap();357358// The groups of the lhs of the expressions do not match the series values,359// we must take the slower path.360if !matches!(ac_in.update_groups, UpdateGroups::No) {361return sort_by_groups_no_match_single(362ac_in,363ac_sort_by,364self.sort_options.descending[0],365&self.expr,366);367};368369let sort_by_s = sort_by_s.pop().unwrap();370let groups = ac_sort_by.groups();371372let (check, groups) = POOL.join(373|| check_groups(groups, ac_in.groups()),374|| {375update_groups_sort_by(376groups,377&sort_by_s,378&SortOptions {379descending: descending[0],380nulls_last: nulls_last[0],381..Default::default()382},383)384},385);386check?;387388groups?389} else {390let groups = ac_sort_by[0].groups();391392let groups = POOL.install(|| {393groups394.par_iter()395.map(|indicator| {396sort_by_groups_multiple_by(397indicator,398&sort_by_s,399&descending,400&nulls_last,401self.sort_options.multithreaded,402self.sort_options.maintain_order,403)404})405.collect::<PolarsResult<_>>()406});407GroupsType::Idx(groups?)408};409410// If the rhs is already aggregated once, it is reordered by the411// group_by operation - we must ensure that we are as well.412if ordered_by_group_operation {413let s = ac_in.aggregated();414ac_in.with_values(s.explode(false).unwrap(), false, None)?;415}416417ac_in.with_groups(groups.into_sliceable());418Ok(ac_in)419}420421fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {422self.input.to_field(input_schema)423}424425fn is_scalar(&self) -> bool {426false427}428}429430431