Path: blob/main/crates/polars-expr/src/expressions/window.rs
8421 views
use std::cmp::Ordering;
use std::fmt::Write;

use arrow::array::PrimitiveArray;
use arrow::bitmap::Bitmap;
use arrow::trusted_len::TrustMyLength;
use polars_core::error::feature_gated;
use polars_core::prelude::row_encode::encode_rows_unordered;
use polars_core::prelude::sort::perfect_sort;
use polars_core::prelude::*;
use polars_core::series::IsSorted;
use polars_core::utils::_split_offsets;
use polars_core::{POOL, downcast_as_macro_arg_physical};
use polars_ops::frame::SeriesJoin;
use polars_ops::frame::join::{ChunkJoinOptIds, private_left_join_multiple_keys};
use polars_ops::prelude::*;
use polars_plan::prelude::*;
use polars_utils::UnitVec;
use polars_utils::sync::SyncPtr;
use polars_utils::vec::PushUnchecked;
use rayon::prelude::*;

use super::*;

/// Physical implementation of a window (`over`) expression:
/// a function applied per partition (`partition_by`), optionally ordered
/// (`order_by`), whose result is mapped back onto the original rows.
pub struct WindowExpr {
    /// the root column that the Function will be applied on.
    /// This will be used to create a smaller DataFrame to prevent taking unneeded columns by index
    pub(crate) group_by: Vec<Arc<dyn PhysicalExpr>>,
    pub(crate) order_by: Option<(Arc<dyn PhysicalExpr>, SortOptions)>,
    pub(crate) apply_columns: Vec<PlSmallStr>,
    pub(crate) phys_function: Arc<dyn PhysicalExpr>,
    pub(crate) mapping: WindowMapping,
    pub(crate) expr: Expr,
    pub(crate) has_different_group_sources: bool,
    pub(crate) output_field: Field,

    pub(crate) all_group_by_are_elementwise: bool,
    pub(crate) order_by_is_elementwise: bool,
}

/// Strategy used to map per-group aggregation results back to row order.
#[cfg_attr(debug_assertions, derive(Debug))]
enum MapStrategy {
    // Join by key, this is the most expensive
    // for reduced aggregations
    Join,
    // explode now
    Explode,
    // Use an arg_sort to map the values back
    Map,
    Nothing,
}

impl WindowExpr {
    /// Build the take-indices that map a flattened per-group aggregation result
    /// back to the original row positions, using a perfect sort over
    /// (new-idx, original-idx) pairs.
    fn map_list_agg_by_arg_sort(
        &self,
        out_column: Column,
        flattened: &Column,
        mut ac: AggregationContext,
        gb: GroupBy,
    ) -> PolarsResult<IdxCa> {
        // idx (new-idx, original-idx)
        let mut idx_mapping = Vec::with_capacity(out_column.len());

        // we already set this buffer so we can reuse the `original_idx` buffer
        // that saves an allocation
        let mut take_idx = vec![];

        // groups are not changed, we can map by doing a standard arg_sort.
        if std::ptr::eq(ac.groups().as_ref(), gb.get_groups()) {
            let mut iter = 0..flattened.len() as IdxSize;
            match ac.groups().as_ref().as_ref() {
                GroupsType::Idx(groups) => {
                    for g in groups.all() {
                        idx_mapping.extend(g.iter().copied().zip(&mut iter));
                    }
                },
                GroupsType::Slice { groups, .. } => {
                    for &[first, len] in groups {
                        idx_mapping.extend((first..first + len).zip(&mut iter));
                    }
                },
            }
        }
        // groups are changed, we use the new group indexes as arguments of the arg_sort
        // and sort by the old indexes
        else {
            let mut original_idx = Vec::with_capacity(out_column.len());
            match gb.get_groups().as_ref() {
                GroupsType::Idx(groups) => {
                    for g in groups.all() {
                        original_idx.extend_from_slice(g)
                    }
                },
                GroupsType::Slice { groups, .. } => {
                    for &[first, len] in groups {
                        original_idx.extend(first..first + len)
                    }
                },
            };

            let mut original_idx_iter = original_idx.iter().copied();

            match ac.groups().as_ref().as_ref() {
                GroupsType::Idx(groups) => {
                    for g in groups.all() {
                        idx_mapping.extend(g.iter().copied().zip(&mut original_idx_iter));
                    }
                },
                GroupsType::Slice { groups, .. } => {
                    for &[first, len] in groups {
                        idx_mapping.extend((first..first + len).zip(&mut original_idx_iter));
                    }
                },
            }
            // Reuse the (now fully consumed) buffer's allocation for the output.
            original_idx.clear();
            take_idx = original_idx;
        }
        // SAFETY:
        // we only have unique indices ranging from 0..len
        unsafe { perfect_sort(&idx_mapping, &mut take_idx) };
        Ok(IdxCa::from_vec(PlSmallStr::EMPTY, take_idx))
    }

    /// Map aggregated list values back to the original row order via arg_sort,
    /// validating group/output lengths and using the window cache when enabled.
    #[allow(clippy::too_many_arguments)]
    fn map_by_arg_sort(
        &self,
        df: &DataFrame,
        out_column: Column,
        flattened: &Column,
        mut ac: AggregationContext,
        group_by_columns: &[Column],
        gb: GroupBy,
        cache_key: String,
        state: &ExecutionState,
    ) -> PolarsResult<Column> {
        // we use an arg_sort to map the values back

        // This is a bit more complicated because the final group tuples may differ from the original
        // so we use the original indices as idx values to arg_sort the original column
        //
        // The example below shows the naive version without group tuple mapping

        // columns
        // a b a a
        //
        // agg list
        // [0, 2, 3]
        // [1]
        //
        // flatten
        //
        // [0, 2, 3, 1]
        //
        // arg_sort
        //
        // [0, 3, 1, 2]
        //
        // take by arg_sorted indexes and voila groups mapped
        // [0, 1, 2, 3]

        if flattened.len() != df.height() {
            // Length mismatch: find the first offending group so the error can
            // point at its key values.
            let ca = out_column.list().unwrap();
            let non_matching_group =
                ca.into_iter()
                    .zip(ac.groups().iter())
                    .find(|(output, group)| {
                        if let Some(output) = output {
                            output.as_ref().len() != group.len()
                        } else {
                            false
                        }
                    });

            if let Some((output, group)) = non_matching_group {
                let first = group.first();
                let group = group_by_columns
                    .iter()
                    .map(|s| format!("{}", s.get(first as usize).unwrap()))
                    .collect::<Vec<_>>();
                polars_bail!(
                    expr = self.expr, ShapeMismatch:
                    "the length of the window expression did not match that of the group\
                    \n> group: {}\n> group length: {}\n> output: '{:?}'",
                    comma_delimited(String::new(), &group), group.len(), output.unwrap()
                );
            } else {
                polars_bail!(
                    expr = self.expr, ShapeMismatch:
                    "the length of the window expression did not match that of the group"
                );
            };
        }

        let idx = if state.cache_window() {
            if let Some(idx) = state.window_cache.get_map(&cache_key) {
                idx
            } else {
                let idx = Arc::new(self.map_list_agg_by_arg_sort(out_column, flattened, ac, gb)?);
                state.window_cache.insert_map(cache_key, idx.clone());
                idx
            }
        } else {
            Arc::new(self.map_list_agg_by_arg_sort(out_column, flattened, ac, gb)?)
        };

        // SAFETY:
        // groups should always be in bounds.
        unsafe { Ok(flattened.take_unchecked(&idx)) }
    }

    /// Evaluate the wrapped physical function per group.
    fn run_aggregation<'a>(
        &self,
        df: &DataFrame,
        state: &ExecutionState,
        gb: &'a GroupBy,
    ) -> PolarsResult<AggregationContext<'a>> {
        let ac = self
            .phys_function
            .evaluate_on_groups(df, gb.get_groups(), state)?;
        Ok(ac)
    }

    fn is_explicit_list_agg(&self) -> bool {
        // col("foo").implode()
        // col("foo").implode().alias()
        // ..
        // col("foo").implode().alias().alias()
        //
        // but not:
        // col("foo").implode().sum().alias()
        // ..
        // col("foo").min()
        let mut explicit_list = false;
        for e in &self.expr {
            if let Expr::Over { function, .. } = e {
                // or list().alias
                let mut finishes_list = false;
                for e in &**function {
                    match e {
                        Expr::Agg(AggExpr::Implode(_)) => {
                            finishes_list = true;
                        },
                        Expr::Alias(_, _) => {},
                        _ => break,
                    }
                }
                explicit_list = finishes_list;
            }
        }

        explicit_list
    }

    fn is_simple_column_expr(&self) -> bool {
        // col()
        // or col().alias()
        let mut simple_col = false;
        for e in &self.expr {
            if let Expr::Over { function, .. } = e {
                // or list().alias
                for e in &**function {
                    match e {
                        Expr::Column(_) => {
                            simple_col = true;
                        },
                        Expr::Alias(_, _) => {},
                        _ => break,
                    }
                }
            }
        }
        simple_col
    }

    fn is_aggregation(&self) -> bool {
        // col()
        // or col().agg()
        let mut agg_col = false;
        for e in &self.expr {
            if let Expr::Over { function, .. } = e {
                // or list().alias
                for e in &**function {
                    match e {
                        Expr::Agg(_) => {
                            agg_col = true;
                        },
                        Expr::Alias(_, _) => {},
                        _ => break,
                    }
                }
            }
        }
        agg_col
    }

    /// Pick how the per-group result should be mapped back to rows, based on
    /// the requested `WindowMapping` and the aggregation state of the result.
    fn determine_map_strategy(
        &self,
        ac: &mut AggregationContext,
        gb: &GroupBy,
    ) -> PolarsResult<MapStrategy> {
        match (self.mapping, ac.agg_state()) {
            // Explode
            // `(col("x").sum() * col("y")).list().over("groups").flatten()`
            (WindowMapping::Explode, _) => Ok(MapStrategy::Explode),
            // // explicit list
            // // `(col("x").sum() * col("y")).list().over("groups")`
            // (false, false, _) => Ok(MapStrategy::Join),
            // aggregations
            //`sum("foo").over("groups")`
            (_, AggState::AggregatedScalar(_)) => Ok(MapStrategy::Join),
            // no explicit aggregations, map over the groups
            //`(col("x").sum() * col("y")).over("groups")`
            (WindowMapping::Join, AggState::AggregatedList(_)) => Ok(MapStrategy::Join),
            // no explicit aggregations, map over the groups
            //`(col("x").sum() * col("y")).over("groups")`
            (WindowMapping::GroupsToRows, AggState::AggregatedList(_)) => {
                if let GroupsType::Slice { .. } = gb.get_groups().as_ref() {
                    // Result can be directly exploded if the input was sorted.
                    ac.groups().as_ref().check_lengths(gb.get_groups())?;
                    Ok(MapStrategy::Explode)
                } else {
                    Ok(MapStrategy::Map)
                }
            },
            // no aggregations, just return column
            // or an aggregation that has been flattened
            // we have to check which one
            //`col("foo").over("groups")`
            (WindowMapping::GroupsToRows, AggState::NotAggregated(_)) => {
                // col()
                // or col().alias()
                if self.is_simple_column_expr() {
                    Ok(MapStrategy::Nothing)
                } else {
                    Ok(MapStrategy::Map)
                }
            },
            (WindowMapping::Join, AggState::NotAggregated(_)) => Ok(MapStrategy::Join),
            // literals, do nothing and let broadcast
            (_, AggState::LiteralScalar(_)) => Ok(MapStrategy::Nothing),
        }
    }
}

// Utility to create partitions and cache keys
pub fn window_function_format_order_by(to: &mut String, e: &Expr, k: &SortOptions) {
    write!(to, "_PL_{:?}{}_{}", e, k.descending, k.nulls_last).unwrap();
}

impl PhysicalExpr for WindowExpr {
    // Note: this was first implemented with expression evaluation but this performed really bad.
    // Therefore we choose the group_by -> apply -> self join approach

    // This first cached the group_by and the join tuples, but rayon under a mutex leads to deadlocks:
    // https://github.com/rayon-rs/rayon/issues/592
    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        // This method does the following:
        // 1. determine group_by tuples based on the group_column
        // 2. apply an aggregation function
        // 3. join the results back to the original dataframe
        // this stores all group values on the original df size
        //
        // we have several strategies for this
        // - 3.1 JOIN
        // Use a join for aggregations like
        // `sum("foo").over("groups")`
        // and explicit `list` aggregations
        // `(col("x").sum() * col("y")).list().over("groups")`
        //
        // - 3.2 EXPLODE
        // Explicit list aggregations that are followed by `over().flatten()`
        // # the fastest method to do things over groups when the groups are sorted.
        // # note that it will require an explicit `list()` call from now on.
        // `(col("x").sum() * col("y")).list().over("groups").flatten()`
        //
        // - 3.3. MAP to original locations
        // This will be done for list aggregations that are not explicitly aggregated as list
        // `(col("x").sum() * col("y")).over("groups")
        // This can be used to reverse, sort, shuffle etc. the values in a group

        // 4. select the final column and return

        if df.height() == 0 {
            // Empty input: return an empty (typed) null column without grouping.
            let field = self.phys_function.to_field(df.schema())?;
            match self.mapping {
                WindowMapping::Join => {
                    return Ok(Column::full_null(
                        field.name().clone(),
                        0,
                        &DataType::List(Box::new(field.dtype().clone())),
                    ));
                },
                _ => {
                    return Ok(Column::full_null(field.name().clone(), 0, field.dtype()));
                },
            }
        }

        let mut group_by_columns = self
            .group_by
            .iter()
            .map(|e| e.evaluate(df, state))
            .collect::<PolarsResult<Vec<_>>>()?;

        // if the keys are sorted
        let sorted_keys = group_by_columns.iter().all(|s| {
            matches!(
                s.is_sorted_flag(),
                IsSorted::Ascending | IsSorted::Descending
            )
        });
        let explicit_list_agg = self.is_explicit_list_agg();

        // if we flatten this column we need to make sure the groups are sorted.
        let mut sort_groups = matches!(self.mapping, WindowMapping::Explode) ||
            // if not
            // `col().over()`
            // and not
            // `col().list().over`
            // and not
            // `col().sum()`
            // and keys are sorted
            // we may optimize with explode call
            (!self.is_simple_column_expr() && !explicit_list_agg && sorted_keys && !self.is_aggregation());

        // overwrite sort_groups for some expressions
        // TODO: fully understand the rationale here.
        if self.has_different_group_sources {
            sort_groups = true
        }

        let create_groups = || {
            let gb = df.group_by_with_series(group_by_columns.clone(), true, sort_groups)?;
            let mut groups = gb.into_groups();

            if let Some((order_by, options)) = &self.order_by {
                let order_by = order_by.evaluate(df, state)?;
                polars_ensure!(order_by.len() == df.height(), ShapeMismatch: "the order by expression evaluated to a length: {} that doesn't match the input DataFrame: {}", order_by.len(), df.height());
                groups = update_groups_sort_by(&groups, order_by.as_materialized_series(), options)?
                    .into_sliceable()
            }

            let out: PolarsResult<GroupPositions> = Ok(groups);
            out
        };

        // Try to get cached grouptuples
        let (mut groups, cache_key) = if state.cache_window() {
            let mut cache_key = String::with_capacity(32 * group_by_columns.len());
            write!(&mut cache_key, "{}", state.branch_idx).unwrap();
            for s in &group_by_columns {
                cache_key.push_str(s.name());
            }
            if let Some((e, options)) = &self.order_by {
                let e = match e.as_expression() {
                    Some(e) => e,
                    None => {
                        polars_bail!(InvalidOperation: "cannot order by this expression in window function")
                    },
                };
                window_function_format_order_by(&mut cache_key, e, options)
            }

            let groups = match state.window_cache.get_groups(&cache_key) {
                Some(groups) => groups,
                None => create_groups()?,
            };
            (groups, cache_key)
        } else {
            (create_groups()?, "".to_string())
        };

        // 2. create GroupBy object and apply aggregation
        let apply_columns = self.apply_columns.clone();

        // some window expressions need sorted groups
        // to make sure that the caches align we sort
        // the groups, so that the cached groups and join keys
        // are consistent among all windows
        if sort_groups || state.cache_window() {
            groups.sort();
            state
                .window_cache
                .insert_groups(cache_key.clone(), groups.clone());
        }

        // broadcast if required
        for col in group_by_columns.iter_mut() {
            if col.len() != df.height() {
                polars_ensure!(
                    col.len() == 1,
                    ShapeMismatch: "columns used as `partition_by` must have the same length as the DataFrame"
                );
                *col = col.new_from_index(0, df.height())
            }
        }

        let gb = GroupBy::new(df, group_by_columns.clone(), groups, Some(apply_columns));

        let mut ac = self.run_aggregation(df, state, &gb)?;

        use MapStrategy::*;

        match self.determine_map_strategy(&mut ac, &gb)? {
            Nothing => {
                let mut out = ac.flat_naive().into_owned();

                if ac.is_literal() {
                    out = out.new_from_index(0, df.height())
                }
                Ok(out.into_column())
            },
            Explode => {
                let out = if self.phys_function.is_scalar() {
                    ac.get_values().clone()
                } else {
                    ac.aggregated().explode(ExplodeOptions {
                        empty_as_null: true,
                        keep_nulls: true,
                    })?
                };
                Ok(out.into_column())
            },
            Map => {
                // TODO!
                // investigate if sorted arrays can be returned directly
                let out_column = ac.aggregated();
                let flattened = out_column.explode(ExplodeOptions {
                    empty_as_null: true,
                    keep_nulls: true,
                })?;
                // we extend the lifetime as we must convince the compiler that ac lives
                // long enough. We drop `GroupBy` when we are done with `ac`.
                let ac = unsafe {
                    std::mem::transmute::<AggregationContext<'_>, AggregationContext<'static>>(ac)
                };
                self.map_by_arg_sort(
                    df,
                    out_column,
                    &flattened,
                    ac,
                    &group_by_columns,
                    gb,
                    cache_key,
                    state,
                )
            },
            Join => {
                let out_column = ac.aggregated();
                // we try to flatten/extend the array by repeating the aggregated value n times
                // where n is the number of members in that group. That way we can try to reuse
                // the same map by arg_sort logic as done for listed aggregations
                let update_groups = !matches!(&ac.update_groups, UpdateGroups::No);
                match (
                    &ac.update_groups,
                    set_by_groups(&out_column, &ac, df.height(), update_groups),
                ) {
                    // for aggregations that reduce like sum, mean, first and are numeric
                    // we take the group locations to directly map them to the right place
                    (UpdateGroups::No, Some(out)) => Ok(out.into_column()),
                    (_, _) => {
                        let keys = gb.keys();

                        let get_join_tuples = || {
                            if group_by_columns.len() == 1 {
                                let mut left = group_by_columns[0].clone();
                                // group key from right column
                                let mut right = keys[0].clone();

                                // Nested key types cannot be hash-joined directly;
                                // row-encode both sides to a binary key first.
                                let (left, right) = if left.dtype().is_nested() {
                                    (
                                        ChunkedArray::<BinaryOffsetType>::with_chunk(
                                            "".into(),
                                            row_encode::_get_rows_encoded_unordered(&[
                                                left.clone()
                                            ])?
                                            .into_array(),
                                        )
                                        .into_series(),
                                        ChunkedArray::<BinaryOffsetType>::with_chunk(
                                            "".into(),
                                            row_encode::_get_rows_encoded_unordered(&[
                                                right.clone()
                                            ])?
                                            .into_array(),
                                        )
                                        .into_series(),
                                    )
                                } else {
                                    (
                                        left.into_materialized_series().clone(),
                                        right.into_materialized_series().clone(),
                                    )
                                };

                                PolarsResult::Ok(Arc::new(
                                    left.hash_join_left(&right, JoinValidation::ManyToMany, true)
                                        .unwrap()
                                        .1,
                                ))
                            } else {
                                let df_right =
                                    unsafe { DataFrame::new_unchecked_infer_height(keys) };
                                let df_left = unsafe {
                                    DataFrame::new_unchecked_infer_height(group_by_columns)
                                };
                                Ok(Arc::new(
                                    private_left_join_multiple_keys(&df_left, &df_right, true)?.1,
                                ))
                            }
                        };

                        // try to get cached join_tuples
                        let join_opt_ids = if state.cache_window() {
                            if let Some(jt) = state.window_cache.get_join(&cache_key) {
                                jt
                            } else {
                                let jt = get_join_tuples()?;
                                state.window_cache.insert_join(cache_key, jt.clone());
                                jt
                            }
                        } else {
                            get_join_tuples()?
                        };

                        let out = materialize_column(&join_opt_ids, &out_column);
                        Ok(out.into_column())
                    },
                }
            },
        }
    }

    fn to_field(&self, _input_schema: &Schema) -> PolarsResult<Field> {
        Ok(self.output_field.clone())
    }

    fn is_scalar(&self) -> bool {
        false
    }

    /// Evaluate a window expression inside an aggregation context by splitting
    /// each existing group into sub-groups per partition key ("window within
    /// group-by"). Only allowed when the `partition_by`/`order_by` expressions
    /// are elementwise.
    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        if self.group_by.is_empty()
            || !self.all_group_by_are_elementwise
            || (self.order_by.is_some() && !self.order_by_is_elementwise)
        {
            polars_bail!(
                InvalidOperation:
                "window expression with non-elementwise `partition_by` or `order_by` not allowed in aggregation context"
            );
        }

        let length_preserving_height = if let Some((c, _)) = state.element.as_ref() {
            c.len()
        } else {
            df.height()
        };

        let function_is_scalar = self.phys_function.is_scalar();
        let needs_remap_to_rows =
            matches!(self.mapping, WindowMapping::GroupsToRows) && !function_is_scalar;

        let partition_by_columns = self
            .group_by
            .iter()
            .map(|e| {
                let mut e = e.evaluate(df, state)?;
                if e.len() == 1 {
                    e = e.new_from_index(0, length_preserving_height);
                }
                // Sanity check: Length Preserving.
                assert_eq!(e.len(), length_preserving_height,);
                Ok(e)
            })
            .collect::<PolarsResult<Vec<_>>>()?;
        let order_by = match &self.order_by {
            None => None,
            Some((e, options)) => {
                let mut e = e.evaluate(df, state)?;
                if e.len() == 1 {
                    e = e.new_from_index(0, length_preserving_height);
                }
                // Sanity check: Length Preserving.
                assert_eq!(e.len(), length_preserving_height);
                let arr: Option<PrimitiveArray<IdxSize>> = if needs_remap_to_rows {
                    feature_gated!("rank", {
                        // Performance: precompute the rank here, so we can avoid dispatching per group
                        // later.
                        use polars_ops::series::SeriesRank;
                        let arr = e.as_materialized_series().rank(
                            RankOptions {
                                method: RankMethod::Ordinal,
                                descending: false,
                            },
                            None,
                        );
                        let arr = arr.idx()?;
                        let arr = arr.rechunk();
                        Some(arr.downcast_as_array().clone())
                    })
                } else {
                    None
                };

                Some((e.clone(), arr, *options))
            },
        };

        let (num_unique_ids, unique_ids) = if partition_by_columns.len() == 1 {
            partition_by_columns[0].unique_id()?
        } else {
            ChunkUnique::unique_id(&encode_rows_unordered(&partition_by_columns)?)?
        };

        // All the groups within the existing groups.
        let subgroups_approx_capacity = groups.len();
        let mut subgroups: Vec<(IdxSize, UnitVec<IdxSize>)> =
            Vec::with_capacity(subgroups_approx_capacity);

        // Indices for the output groups. Not used with `WindowMapping::Explode`.
        let mut gather_indices_offset = 0;
        let mut gather_indices: Vec<(IdxSize, UnitVec<IdxSize>)> =
            Vec::with_capacity(if matches!(self.mapping, WindowMapping::Explode) {
                0
            } else {
                groups.len()
            });
        // Slices for the output groups. Only used with `WindowMapping::Explode`.
        let mut strategy_explode_groups: Vec<[IdxSize; 2]> =
            Vec::with_capacity(if matches!(self.mapping, WindowMapping::Explode) {
                groups.len()
            } else {
                0
            });

        // Amortized vectors to reorder based on `order_by`.
        let mut amort_arg_sort = Vec::new();
        let mut amort_offsets = Vec::new();

        // Amortized vectors to gather per group data.
        let mut amort_subgroups_order = Vec::with_capacity(num_unique_ids as usize);
        let mut amort_subgroups_sizes = Vec::with_capacity(num_unique_ids as usize);
        let mut amort_subgroups_indices = (0..num_unique_ids)
            .map(|_| (0, UnitVec::new()))
            .collect::<Vec<(IdxSize, UnitVec<IdxSize>)>>();

        // Local code template (not a reusable macro): splits one outer group
        // into per-partition subgroups. `$iter` yields the group's row indices;
        // `$get` maps a position within the group to its row index.
        macro_rules! map_window_groups {
            ($iter:expr, $get:expr) => {
                let mut subgroup_gather_indices =
                    UnitVec::with_capacity(if matches!(self.mapping, WindowMapping::Explode) {
                        0
                    } else {
                        $iter.len()
                    });

                amort_subgroups_order.clear();
                amort_subgroups_sizes.clear();
                amort_subgroups_sizes.resize(num_unique_ids as usize, 0);

                // Determine sizes per subgroup.
                for i in $iter.clone() {
                    let id = *unsafe { unique_ids.get_unchecked(i as usize) };
                    let size = unsafe { amort_subgroups_sizes.get_unchecked_mut(id as usize) };
                    if *size == 0 {
                        unsafe { amort_subgroups_order.push_unchecked(id) };
                    }
                    *size += 1;
                }

                if matches!(self.mapping, WindowMapping::Explode) {
                    strategy_explode_groups.push([
                        subgroups.len() as IdxSize,
                        amort_subgroups_order.len() as IdxSize,
                    ]);
                }

                // Set starting gather indices and reserve capacity per subgroup.
                let mut offset = if needs_remap_to_rows {
                    gather_indices_offset
                } else {
                    subgroups.len() as IdxSize
                };
                for &id in &amort_subgroups_order {
                    let size = *unsafe { amort_subgroups_sizes.get_unchecked(id as usize) };
                    let (next_gather_idx, indices) =
                        unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
                    indices.reserve(size as usize);
                    *next_gather_idx = offset;
                    offset += if needs_remap_to_rows { size } else { 1 };
                }

                // Collect gather indices.
                if matches!(self.mapping, WindowMapping::Explode) {
                    for i in $iter {
                        let id = *unsafe { unique_ids.get_unchecked(i as usize) };
                        let (_, indices) =
                            unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
                        unsafe { indices.push_unchecked(i) };
                    }
                } else {
                    // If we are remapping exploded rows back to rows and are reordering, we need
                    // to ensure we reorder the gather indices as well. Reordering the `subgroup`
                    // indices is done later.
                    //
                    // Having precalculated both the `unique_ids` and `order_by_ranks` in
                    // efficient kernels, we can now do a relatively efficient arg_sort per group.
                    // This is still horrendously slow, but at least not as bad as it would be if
                    // you did this naively.
                    if needs_remap_to_rows && let Some((_, arr, options)) = &order_by {
                        let arr = arr.as_ref().unwrap();
                        amort_arg_sort.clear();
                        amort_arg_sort.extend(0..$iter.len() as IdxSize);
                        match arr.validity() {
                            None => {
                                let arr = arr.values().as_slice();
                                amort_arg_sort.sort_by(|a, b| {
                                    let in_group_idx_a = $get(*a as usize) as usize;
                                    let in_group_idx_b = $get(*b as usize) as usize;

                                    let order_a = unsafe { arr.get_unchecked(in_group_idx_a) };
                                    let order_b = unsafe { arr.get_unchecked(in_group_idx_b) };

                                    let mut cmp = order_a.cmp(&order_b);
                                    // Performance: This can generally be handled branchlessly.
                                    if options.descending {
                                        cmp = cmp.reverse();
                                    }
                                    cmp
                                });
                            },
                            Some(validity) => {
                                let arr = arr.values().as_slice();
                                amort_arg_sort.sort_by(|a, b| {
                                    let in_group_idx_a = $get(*a as usize) as usize;
                                    let in_group_idx_b = $get(*b as usize) as usize;

                                    let is_valid_a =
                                        unsafe { validity.get_bit_unchecked(in_group_idx_a) };
                                    let is_valid_b =
                                        unsafe { validity.get_bit_unchecked(in_group_idx_b) };
                                    let order_a = unsafe { arr.get_unchecked(in_group_idx_a) };
                                    let order_b = unsafe { arr.get_unchecked(in_group_idx_b) };

                                    if !is_valid_a & !is_valid_b {
                                        return Ordering::Equal;
                                    }

                                    let mut cmp = order_a.cmp(&order_b);
                                    if !is_valid_a {
                                        cmp = Ordering::Less;
                                    }
                                    if !is_valid_b {
                                        cmp = Ordering::Greater;
                                    }
                                    if options.descending
                                        | ((!is_valid_a | !is_valid_b) & options.nulls_last)
                                    {
                                        cmp = cmp.reverse();
                                    }
                                    cmp
                                });
                            },
                        }

                        amort_offsets.clear();
                        amort_offsets.resize($iter.len(), 0);
                        for &id in &amort_subgroups_order {
                            amort_subgroups_sizes[id as usize] = 0;
                        }

                        for &idx in &amort_arg_sort {
                            let in_group_idx = $get(idx as usize);
                            let id = *unsafe { unique_ids.get_unchecked(in_group_idx as usize) };
                            amort_offsets[idx as usize] = amort_subgroups_sizes[id as usize];
                            amort_subgroups_sizes[id as usize] += 1;
                        }

                        for (i, offset) in $iter.zip(&amort_offsets) {
                            let id = *unsafe { unique_ids.get_unchecked(i as usize) };
                            let (next_gather_idx, indices) =
                                unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
                            unsafe {
                                subgroup_gather_indices.push_unchecked(*next_gather_idx + *offset)
                            };
                            unsafe { indices.push_unchecked(i) };
                        }
                    } else {
                        for i in $iter {
                            let id = *unsafe { unique_ids.get_unchecked(i as usize) };
                            let (next_gather_idx, indices) =
                                unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
                            unsafe { subgroup_gather_indices.push_unchecked(*next_gather_idx) };
                            *next_gather_idx += IdxSize::from(needs_remap_to_rows);
                            unsafe { indices.push_unchecked(i) };
                        }
                    }
                }

                // Push groups into nested_groups.
                subgroups.extend(amort_subgroups_order.iter().map(|&id| {
                    let (_, indices) =
                        unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
                    let indices = std::mem::take(indices);
                    (*unsafe { indices.get_unchecked(0) }, indices)
                }));

                if !matches!(self.mapping, WindowMapping::Explode) {
                    gather_indices_offset += subgroup_gather_indices.len() as IdxSize;
                    gather_indices.push((
                        subgroup_gather_indices.first().copied().unwrap_or(0),
                        subgroup_gather_indices,
                    ));
                }
            };
        }
        match groups.as_ref() {
            GroupsType::Idx(idxs) => {
                for g in idxs.all() {
                    map_window_groups!(g.iter().copied(), (|i: usize| g[i]));
                }
            },
            GroupsType::Slice {
                groups,
                overlapping: _,
                monotonic: _,
            } => {
                for [s, l] in groups.iter() {
                    let s = *s;
                    let l = *l;
                    // SAFETY: the range `s..s + l` has exactly `l` elements.
                    let iter = unsafe { TrustMyLength::new(s..s + l, l as usize) };
                    map_window_groups!(iter, (|i: usize| s + i as IdxSize));
                }
            },
        }

        let mut subgroups = GroupsType::Idx(subgroups.into());
        if let Some((order_by, _, options)) = order_by {
            subgroups =
                update_groups_sort_by(&subgroups, order_by.as_materialized_series(), &options)?;
        }
        let subgroups = subgroups.into_sliceable();
        let mut data = self
            .phys_function
            .evaluate_on_groups(df, &subgroups, state)?
            .finalize();

        let final_groups = if matches!(self.mapping, WindowMapping::Explode) {
            if !function_is_scalar {
                let (data_s, offsets) = data.list()?.explode_and_offsets(ExplodeOptions {
                    empty_as_null: false,
                    keep_nulls: false,
                })?;
                data = data_s.into_column();

                // Rewrite the [start, length] slices from subgroup counts to
                // exploded-row counts.
                let mut exploded_offset = 0;
                for [start, length] in strategy_explode_groups.iter_mut() {
                    let exploded_start = exploded_offset;
                    let exploded_length = offsets
                        .lengths()
                        .skip(*start as usize)
                        .take(*length as usize)
                        .sum::<usize>() as IdxSize;
                    exploded_offset += exploded_length;
                    *start = exploded_start;
                    *length = exploded_length;
                }
            }
            GroupsType::new_slice(strategy_explode_groups, false, true)
        } else {
            if needs_remap_to_rows {
                let data_l = data.list()?;
                assert_eq!(data_l.len(), subgroups.len());
                let lengths = data_l.lst_lengths();
                let length_mismatch = match subgroups.as_ref() {
                    GroupsType::Idx(idx) => idx
                        .all()
                        .iter()
                        .zip(&lengths)
                        .any(|(i, l)| i.len() as IdxSize != l.unwrap()),
                    GroupsType::Slice {
                        groups,
                        overlapping: _,
                        monotonic: _,
                    } => groups
                        .iter()
                        .zip(&lengths)
                        .any(|([_, i], l)| *i != l.unwrap()),
                };

                polars_ensure!(
                    !length_mismatch,
                    expr = self.expr, ShapeMismatch:
                    "the length of the window expression did not match that of the group"
                );

                data = data_l
                    .explode(ExplodeOptions {
                        empty_as_null: false,
                        keep_nulls: true,
                    })?
                    .into_column();
            }
            GroupsType::Idx(gather_indices.into())
        }
        .into_sliceable();

        Ok(AggregationContext {
            state: AggState::NotAggregated(data),
            groups: Cow::Owned(final_groups),
            update_groups: UpdateGroups::No,
            original_len: false,
        })
    }

    fn as_expression(&self) -> Option<&Expr> {
        Some(&self.expr)
    }
}

/// Gather `out_column` by the (possibly null) join indices produced by the
/// left join in the `Join` strategy.
fn materialize_column(join_opt_ids: &ChunkJoinOptIds, out_column: &Column) -> Column {
    {
        use arrow::Either;
        use polars_ops::chunked_array::TakeChunked;

        match join_opt_ids {
            Either::Left(ids) => unsafe {
                IdxCa::with_nullable_idx(ids, |idx| out_column.take_unchecked(idx))
            },
            Either::Right(ids) => unsafe { out_column.take_opt_chunked_unchecked(ids, false) },
        }
    }
}

/// Simple reducing aggregation can be set by the groups
fn set_by_groups(
    s: &Column,
    ac: &AggregationContext,
    len: usize,
    update_groups: bool,
) -> Option<Column> {
    if update_groups || !ac.original_len {
        return None;
    }
    if s.dtype().to_physical().is_primitive_numeric() {
        let dtype = s.dtype();
        let s = s.to_physical_repr();

        macro_rules! dispatch {
            ($ca:expr) => {{ Some(set_numeric($ca, &ac.groups, len)) }};
        }
        downcast_as_macro_arg_physical!(&s, dispatch)
            .map(|s| unsafe { s.from_physical_unchecked(dtype) }.unwrap())
            .map(Column::from)
    } else {
        None
    }
}

/// Scatter one aggregated value per group into every row position of that
/// group, in parallel. Writes into an uninitialized buffer through raw
/// pointers; each group writes disjoint slots so writers never alias.
fn set_numeric<T: PolarsNumericType>(
    ca: &ChunkedArray<T>,
    groups: &GroupsType,
    len: usize,
) -> Series {
    let mut values = Vec::with_capacity(len);
    let ptr: *mut T::Native = values.as_mut_ptr();
    // SAFETY:
    // we will write from different threads but we will never alias.
    let sync_ptr_values = unsafe { SyncPtr::new(ptr) };

    if ca.null_count() == 0 {
        let ca = ca.rechunk();
        match groups {
            GroupsType::Idx(groups) => {
                let agg_vals = ca.cont_slice().expect("rechunked");
                POOL.install(|| {
                    agg_vals
                        .par_iter()
                        .zip(groups.all().par_iter())
                        .for_each(|(v, g)| {
                            let ptr = sync_ptr_values.get();
                            for idx in g.as_slice() {
                                debug_assert!((*idx as usize) < len);
                                unsafe { *ptr.add(*idx as usize) = *v }
                            }
                        })
                })
            },
            GroupsType::Slice { groups, .. } => {
                let agg_vals = ca.cont_slice().expect("rechunked");
                POOL.install(|| {
                    agg_vals
                        .par_iter()
                        .zip(groups.par_iter())
                        .for_each(|(v, [start, g_len])| {
                            let ptr = sync_ptr_values.get();
                            let start = *start as usize;
                            let end = start + *g_len as usize;
                            for idx in start..end {
                                debug_assert!(idx < len);
                                unsafe { *ptr.add(idx) = *v }
                            }
                        })
                });
            },
        }

        // SAFETY: we have written all slots
        unsafe { values.set_len(len) }
        ChunkedArray::<T>::new_vec(ca.name().clone(), values).into_series()
    } else {
        // We don't use a mutable bitmap as bits will have race conditions!
        // A single byte might alias if we write from single threads.
        let mut validity: Vec<bool> = vec![false; len];
        let validity_ptr = validity.as_mut_ptr();
        let sync_ptr_validity = unsafe { SyncPtr::new(validity_ptr) };

        let n_threads = POOL.current_num_threads();
        let offsets = _split_offsets(ca.len(), n_threads);

        match groups {
            GroupsType::Idx(groups) => offsets.par_iter().for_each(|(offset, offset_len)| {
                let offset = *offset;
                let offset_len = *offset_len;
                let ca = ca.slice(offset as i64, offset_len);
                let groups = &groups.all()[offset..offset + offset_len];
                let values_ptr = sync_ptr_values.get();
                let validity_ptr = sync_ptr_validity.get();

                ca.iter().zip(groups.iter()).for_each(|(opt_v, g)| {
                    for idx in g.as_slice() {
                        let idx = *idx as usize;
                        debug_assert!(idx < len);
                        unsafe {
                            match opt_v {
                                Some(v) => {
                                    *values_ptr.add(idx) = v;
                                    *validity_ptr.add(idx) = true;
                                },
                                None => {
                                    *values_ptr.add(idx) = T::Native::default();
                                    *validity_ptr.add(idx) = false;
                                },
                            };
                        }
                    }
                })
            }),
            GroupsType::Slice { groups, .. } => {
                offsets.par_iter().for_each(|(offset, offset_len)| {
                    let offset = *offset;
                    let offset_len = *offset_len;
                    let ca = ca.slice(offset as i64, offset_len);
                    let groups = &groups[offset..offset + offset_len];
                    let values_ptr = sync_ptr_values.get();
                    let validity_ptr = sync_ptr_validity.get();

                    for (opt_v, [start, g_len]) in ca.iter().zip(groups.iter()) {
                        let start = *start as usize;
                        let end = start + *g_len as usize;
                        for idx in start..end {
                            debug_assert!(idx < len);
                            unsafe {
                                match opt_v {
                                    Some(v) => {
                                        *values_ptr.add(idx) = v;
                                        *validity_ptr.add(idx) = true;
                                    },
                                    None => {
                                        *values_ptr.add(idx) = T::Native::default();
                                        *validity_ptr.add(idx) = false;
                                    },
                                };
                            }
                        }
                    }
                })
            },
        }
        // SAFETY: we have written all slots
        unsafe { values.set_len(len) }
        let validity = Bitmap::from(validity);
        let arr = PrimitiveArray::new(
            T::get_static_dtype()
                .to_physical()
                .to_arrow(CompatLevel::newest()),
            values.into(),
            Some(validity),
        );
        Series::try_from((ca.name().clone(), arr.boxed())).unwrap()
    }
}