Path: blob/main/crates/polars-expr/src/expressions/window.rs
use std::fmt::Write;

use arrow::array::PrimitiveArray;
use arrow::bitmap::Bitmap;
use polars_core::prelude::*;
use polars_core::series::IsSorted;
use polars_core::utils::_split_offsets;
use polars_core::{POOL, downcast_as_macro_arg_physical};
use polars_ops::frame::SeriesJoin;
use polars_ops::frame::join::{ChunkJoinOptIds, private_left_join_multiple_keys};
use polars_ops::prelude::*;
use polars_plan::prelude::*;
use polars_utils::sort::perfect_sort;
use polars_utils::sync::SyncPtr;
use rayon::prelude::*;

use super::*;

pub struct WindowExpr {
    /// The root column that the function will be applied on.
    /// This will be used to create a smaller DataFrame to prevent taking unneeded columns by index.
    pub(crate) group_by: Vec<Arc<dyn PhysicalExpr>>,
    pub(crate) order_by: Option<(Arc<dyn PhysicalExpr>, SortOptions)>,
    pub(crate) apply_columns: Vec<PlSmallStr>,
    /// A function Expr, e.g. Mean, Median, Max, etc.
    pub(crate) function: Expr,
    pub(crate) phys_function: Arc<dyn PhysicalExpr>,
    pub(crate) mapping: WindowMapping,
    pub(crate) expr: Expr,
    pub(crate) has_different_group_sources: bool,
}

#[cfg_attr(debug_assertions, derive(Debug))]
enum MapStrategy {
    // Join by key; this is the most expensive,
    // used for reduced aggregations
    Join,
    // explode now
    Explode,
    // Use an arg_sort to map the values back
    Map,
    Nothing,
}

impl WindowExpr {
    fn map_list_agg_by_arg_sort(
        &self,
        out_column: Column,
        flattened: &Column,
        mut ac: AggregationContext,
        gb: GroupBy,
    ) -> PolarsResult<IdxCa> {
        // idx (new-idx, original-idx)
        let mut idx_mapping = Vec::with_capacity(out_column.len());

        // this buffer is only set in the second branch, where we can reuse the
        // `original_idx` buffer; that saves an allocation
        let mut take_idx = vec![];

        // groups are not changed, we can map by doing a standard arg_sort.
        if std::ptr::eq(ac.groups().as_ref(), gb.get_groups()) {
            let mut iter = 0..flattened.len() as IdxSize;
            match ac.groups().as_ref().as_ref() {
                GroupsType::Idx(groups) => {
                    for g in groups.all() {
                        idx_mapping.extend(g.iter().copied().zip(&mut iter));
                    }
                },
                GroupsType::Slice { groups, .. } => {
                    for &[first, len] in groups {
                        idx_mapping.extend((first..first + len).zip(&mut iter));
                    }
                },
            }
        }
        // groups are changed, we use the new group indexes as arguments of the arg_sort
        // and sort by the old indexes
        else {
            let mut original_idx = Vec::with_capacity(out_column.len());
            match gb.get_groups().as_ref() {
                GroupsType::Idx(groups) => {
                    for g in groups.all() {
                        original_idx.extend_from_slice(g)
                    }
                },
                GroupsType::Slice { groups, .. } => {
                    for &[first, len] in groups {
                        original_idx.extend(first..first + len)
                    }
                },
            };

            let mut original_idx_iter = original_idx.iter().copied();

            match ac.groups().as_ref().as_ref() {
                GroupsType::Idx(groups) => {
                    for g in groups.all() {
                        idx_mapping.extend(g.iter().copied().zip(&mut original_idx_iter));
                    }
                },
                GroupsType::Slice { groups, .. } => {
                    for &[first, len] in groups {
                        idx_mapping.extend((first..first + len).zip(&mut original_idx_iter));
                    }
                },
            }
            original_idx.clear();
            take_idx = original_idx;
        }
        // SAFETY:
        // we only have unique indices ranging from 0..len
        unsafe { perfect_sort(&POOL, &idx_mapping, &mut take_idx) };
        Ok(IdxCa::from_vec(PlSmallStr::EMPTY, take_idx))
    }
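
    // Toy walk-through of the branch above where the groups are unchanged
    // (illustrative; it assumes `perfect_sort` scatters each
    // `(original_idx, flattened_pos)` pair so that `take_idx[original_idx] = flattened_pos`):
    //
    //   groups:      [[0, 2], [1]]            // original row indices per group
    //   flattened:   positions 0..3           // group 0's values, then group 1's
    //   idx_mapping: [(0, 0), (2, 1), (1, 2)]
    //   take_idx:    [0, 2, 1]
    //
    // Taking the flattened column at `take_idx` places every value back at its
    // original row.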

    #[allow(clippy::too_many_arguments)]
    fn map_by_arg_sort(
        &self,
        df: &DataFrame,
        out_column: Column,
        flattened: &Column,
        mut ac: AggregationContext,
        group_by_columns: &[Column],
        gb: GroupBy,
        cache_key: String,
        state: &ExecutionState,
    ) -> PolarsResult<Column> {
        // we use an arg_sort to map the values back

        // This is a bit more complicated because the final group tuples may differ from the original,
        // so we use the original indices as idx values to arg_sort the original column
        //
        // The example below shows the naive version without group tuple mapping

        // columns
        // a b a a
        //
        // agg list
        // [0, 2, 3]
        // [1]
        //
        // flatten
        //
        // [0, 2, 3, 1]
        //
        // arg_sort
        //
        // [0, 3, 1, 2]
        //
        // take by the arg_sorted indices and voilà, the groups are mapped
        // [0, 1, 2, 3]

        if flattened.len() != df.height() {
            let ca = out_column.list().unwrap();
            let non_matching_group =
                ca.into_iter()
                    .zip(ac.groups().iter())
                    .find(|(output, group)| {
                        if let Some(output) = output {
                            output.as_ref().len() != group.len()
                        } else {
                            false
                        }
                    });

            if let Some((output, group)) = non_matching_group {
                let first = group.first();
                let group = group_by_columns
                    .iter()
                    .map(|s| format!("{}", s.get(first as usize).unwrap()))
                    .collect::<Vec<_>>();
                polars_bail!(
                    expr = self.expr, ShapeMismatch:
                    "the length of the window expression did not match that of the group\
                    \n> group: {}\n> group length: {}\n> output: '{:?}'",
                    comma_delimited(String::new(), &group), group.len(), output.unwrap()
                );
            } else {
                polars_bail!(
                    expr = self.expr, ShapeMismatch:
                    "the length of the window expression did not match that of the group"
                );
            };
        }

        let idx = if state.cache_window() {
            if let Some(idx) = state.window_cache.get_map(&cache_key) {
                idx
            } else {
                let idx = Arc::new(self.map_list_agg_by_arg_sort(out_column, flattened, ac, gb)?);
                state.window_cache.insert_map(cache_key, idx.clone());
                idx
            }
        } else {
            Arc::new(self.map_list_agg_by_arg_sort(out_column, flattened, ac, gb)?)
        };

        // SAFETY:
        // groups should always be in bounds.
        unsafe { Ok(flattened.take_unchecked(&idx)) }
    }

    fn run_aggregation<'a>(
        &self,
        df: &DataFrame,
        state: &ExecutionState,
        gb: &'a GroupBy,
    ) -> PolarsResult<AggregationContext<'a>> {
        let ac = self
            .phys_function
            .evaluate_on_groups(df, gb.get_groups(), state)?;
        Ok(ac)
    }
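
    // Illustrative inputs for the three shape predicates below, mirroring the
    // examples given in their bodies:
    //
    //   col("foo").implode().alias("a")        -> is_explicit_list_agg()  == true
    //   col("foo").implode().sum().alias("a")  -> is_explicit_list_agg()  == false
    //   col("foo").alias("a")                  -> is_simple_column_expr() == true
    //   col("foo").min()                       -> is_aggregation()        == true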

    fn is_explicit_list_agg(&self) -> bool {
        // col("foo").implode()
        // col("foo").implode().alias()
        // ..
        // col("foo").implode().alias().alias()
        //
        // but not:
        // col("foo").implode().sum().alias()
        // ..
        // col("foo").min()
        let mut explicit_list = false;
        for e in &self.expr {
            if let Expr::Window { function, .. } = e {
                // or list().alias
                let mut finishes_list = false;
                for e in &**function {
                    match e {
                        Expr::Agg(AggExpr::Implode(_)) => {
                            finishes_list = true;
                        },
                        Expr::Alias(_, _) => {},
                        _ => break,
                    }
                }
                explicit_list = finishes_list;
            }
        }

        explicit_list
    }

    fn is_simple_column_expr(&self) -> bool {
        // col()
        // or col().alias()
        let mut simple_col = false;
        for e in &self.expr {
            if let Expr::Window { function, .. } = e {
                // or list().alias
                for e in &**function {
                    match e {
                        Expr::Column(_) => {
                            simple_col = true;
                        },
                        Expr::Alias(_, _) => {},
                        _ => break,
                    }
                }
            }
        }
        simple_col
    }

    fn is_aggregation(&self) -> bool {
        // col()
        // or col().agg()
        let mut agg_col = false;
        for e in &self.expr {
            if let Expr::Window { function, .. } = e {
                // or list().alias
                for e in &**function {
                    match e {
                        Expr::Agg(_) => {
                            agg_col = true;
                        },
                        Expr::Alias(_, _) => {},
                        _ => break,
                    }
                }
            }
        }
        agg_col
    }

    fn determine_map_strategy(
        &self,
        agg_state: &AggState,
        gb: &GroupBy,
    ) -> PolarsResult<MapStrategy> {
        match (self.mapping, agg_state) {
            // Explode
            // `(col("x").sum() * col("y")).list().over("groups").flatten()`
            (WindowMapping::Explode, _) => Ok(MapStrategy::Explode),
            // // explicit list
            // // `(col("x").sum() * col("y")).list().over("groups")`
            // (false, false, _) => Ok(MapStrategy::Join),
            // aggregations
            //`sum("foo").over("groups")`
            (_, AggState::AggregatedScalar(_)) => Ok(MapStrategy::Join),
            // no explicit aggregations, map over the groups
            //`(col("x").sum() * col("y")).over("groups")`
            (WindowMapping::Join, AggState::AggregatedList(_)) => Ok(MapStrategy::Join),
            // no explicit aggregations, map over the groups
            //`(col("x").sum() * col("y")).over("groups")`
            (WindowMapping::GroupsToRows, AggState::AggregatedList(_)) => {
                if let GroupsType::Slice { .. } = gb.get_groups().as_ref() {
                    // Result can be directly exploded if the input was sorted.
                    Ok(MapStrategy::Explode)
                } else {
                    Ok(MapStrategy::Map)
                }
            },
            // no aggregations, just return the column
            // or an aggregation that has been flattened
            // we have to check which one
            //`col("foo").over("groups")`
            (WindowMapping::GroupsToRows, AggState::NotAggregated(_)) => {
                // col()
                // or col().alias()
                if self.is_simple_column_expr() {
                    Ok(MapStrategy::Nothing)
                } else {
                    Ok(MapStrategy::Map)
                }
            },
            (WindowMapping::Join, AggState::NotAggregated(_)) => Ok(MapStrategy::Join),
            // literals, do nothing and let it broadcast
            (_, AggState::LiteralScalar(_)) => Ok(MapStrategy::Nothing),
        }
    }
}
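
// Quick reference for `determine_map_strategy` (an illustrative summary of the
// match arms above, not an exhaustive specification):
//
//   `expr.over(..)` with `WindowMapping::Explode`   -> Explode
//   `sum("foo").over("groups")`                     -> Join (one scalar per group)
//   `(col("x").sum() * col("y")).over("groups")`    -> Map, or Explode when the
//                                                      groups are sorted slices
//   the same expression with `WindowMapping::Join`  -> Join
//   `col("foo").over("groups")`                     -> Nothing (simple passthrough)
//   literals                                        -> Nothing (broadcast later)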

// Utility to create partitions and cache keys
pub fn window_function_format_order_by(to: &mut String, e: &Expr, k: &SortOptions) {
    write!(to, "_PL_{:?}{}_{}", e, k.descending, k.nulls_last).unwrap();
}

impl PhysicalExpr for WindowExpr {
    // Note: this was first implemented with expression evaluation, but that performed really badly.
    // Therefore we choose the group_by -> apply -> self-join approach

    // This first cached the group_by and the join tuples, but rayon under a mutex leads to deadlocks:
    // https://github.com/rayon-rs/rayon/issues/592
    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        // This method does the following:
        // 1. determine group_by tuples based on the group_column
        // 2. apply an aggregation function
        // 3. join the results back to the original dataframe
        //    this stores all group values on the original df size
        //
        // we have several strategies for this
        // - 3.1 JOIN
        //    Use a join for aggregations like
        //    `sum("foo").over("groups")`
        //    and explicit `list` aggregations
        //    `(col("x").sum() * col("y")).list().over("groups")`
        //
        // - 3.2 EXPLODE
        //    Explicit list aggregations that are followed by `over().flatten()`
        //    # the fastest method to do things over groups when the groups are sorted.
        //    # note that it will require an explicit `list()` call from now on.
        //    `(col("x").sum() * col("y")).list().over("groups").flatten()`
        //
        // - 3.3 MAP to original locations
        //    This will be done for list aggregations that are not explicitly aggregated as list
        //    `(col("x").sum() * col("y")).over("groups")`
        //    This can be used to reverse, sort, shuffle etc. the values in a group
        //
        // 4. select the final column and return

        if df.is_empty() {
            let field = self.phys_function.to_field(df.schema())?;
            match self.mapping {
                WindowMapping::Join => {
                    return Ok(Column::full_null(
                        field.name().clone(),
                        0,
                        &DataType::List(Box::new(field.dtype().clone())),
                    ));
                },
                _ => {
                    return Ok(Column::full_null(field.name().clone(), 0, field.dtype()));
                },
            }
        }

        let group_by_columns = self
            .group_by
            .iter()
            .map(|e| e.evaluate(df, state))
            .collect::<PolarsResult<Vec<_>>>()?;

        // if the keys are sorted
        let sorted_keys = group_by_columns.iter().all(|s| {
            matches!(
                s.is_sorted_flag(),
                IsSorted::Ascending | IsSorted::Descending
            )
        });
        let explicit_list_agg = self.is_explicit_list_agg();

        // if we flatten this column we need to make sure the groups are sorted.
        let mut sort_groups = matches!(self.mapping, WindowMapping::Explode) ||
            // if not
            // `col().over()`
            // and not
            // `col().list().over`
            // and not
            // `col().sum()`
            // and the keys are sorted
            // we may optimize with an explode call
            (!self.is_simple_column_expr() && !explicit_list_agg && sorted_keys && !self.is_aggregation());

        // overwrite sort_groups for some expressions
        // TODO: fully understand the rationale here.
        if self.has_different_group_sources {
            sort_groups = true
        }

        let create_groups = || {
            let gb = df.group_by_with_series(group_by_columns.clone(), true, sort_groups)?;
            let mut groups = gb.take_groups();

            if let Some((order_by, options)) = &self.order_by {
                let order_by = order_by.evaluate(df, state)?;
                polars_ensure!(order_by.len() == df.height(), ShapeMismatch: "the order by expression evaluated to length {}, which doesn't match that of the input DataFrame: {}", order_by.len(), df.height());
                groups = update_groups_sort_by(&groups, order_by.as_materialized_series(), options)?
                    .into_sliceable()
            }

            let out: PolarsResult<GroupPositions> = Ok(groups);
            out
        };

        // Try to get cached group tuples
        let (mut groups, cache_key) = if state.cache_window() {
            let mut cache_key = String::with_capacity(32 * group_by_columns.len());
            write!(&mut cache_key, "{}", state.branch_idx).unwrap();
            for s in &group_by_columns {
                cache_key.push_str(s.name());
            }
            if let Some((e, options)) = &self.order_by {
                let e = match e.as_expression() {
                    Some(e) => e,
                    None => {
                        polars_bail!(InvalidOperation: "cannot order by this expression in window function")
                    },
                };
                window_function_format_order_by(&mut cache_key, e, options)
            }

            let groups = match state.window_cache.get_groups(&cache_key) {
                Some(groups) => groups,
                None => create_groups()?,
            };
            (groups, cache_key)
        } else {
            (create_groups()?, "".to_string())
        };
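
        // Cache-key anatomy (illustrative): for `branch_idx == 0`, key columns
        // ["a", "b"] and an order-by on col("t") with `descending: true` and
        // `nulls_last: true`, the key is roughly `0ab_PL_<Debug of col("t")>true_true`,
        // i.e. unique per (branch, key columns, ordering).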

        // 2. create GroupBy object and apply aggregation
        let apply_columns = self.apply_columns.clone();

        // some window expressions need sorted groups
        // to make sure that the caches align we sort
        // the groups, so that the cached groups and join keys
        // are consistent among all windows
        if sort_groups || state.cache_window() {
            groups.sort();
            state
                .window_cache
                .insert_groups(cache_key.clone(), groups.clone());
        }
        let gb = GroupBy::new(df, group_by_columns.clone(), groups, Some(apply_columns));

        let mut ac = self.run_aggregation(df, state, &gb)?;

        use MapStrategy::*;
        match self.determine_map_strategy(ac.agg_state(), &gb)? {
            Nothing => {
                let mut out = ac.flat_naive().into_owned();

                if ac.is_literal() {
                    out = out.new_from_index(0, df.height())
                }
                Ok(out.into_column())
            },
            Explode => {
                let out = ac.aggregated().explode(false)?;
                Ok(out.into_column())
            },
            Map => {
                // TODO!
                // investigate if sorted arrays can be returned directly
                let out_column = ac.aggregated();
                let flattened = out_column.explode(false)?;
                // we extend the lifetime as we must convince the compiler that ac lives
                // long enough. We drop `GroupBy` when we are done with `ac`.
                let ac = unsafe {
                    std::mem::transmute::<AggregationContext<'_>, AggregationContext<'static>>(ac)
                };
                self.map_by_arg_sort(
                    df,
                    out_column,
                    &flattened,
                    ac,
                    &group_by_columns,
                    gb,
                    cache_key,
                    state,
                )
            },
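            // Toy example for the join fallback in the arm below (illustrative):
            // a group column ["a", "b", "a"] left-joined against the unique keys
            // ["a", "b"] yields row ids [0, 1, 0], so `materialize_column` gathers
            // each group's aggregate back onto every original row.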
            Join => {
                let out_column = ac.aggregated();
                // we try to flatten/extend the array by repeating the aggregated value n times
                // where n is the number of members in that group. That way we can try to reuse
                // the same map by arg_sort logic as done for listed aggregations
                let update_groups = !matches!(&ac.update_groups, UpdateGroups::No);
                match (
                    &ac.update_groups,
                    set_by_groups(&out_column, &ac, df.height(), update_groups),
                ) {
                    // for aggregations that reduce like sum, mean, first and are numeric
                    // we take the group locations to directly map them to the right place
                    (UpdateGroups::No, Some(out)) => Ok(out.into_column()),
                    (_, _) => {
                        let keys = gb.keys();

                        let get_join_tuples = || {
                            if group_by_columns.len() == 1 {
                                let mut left = group_by_columns[0].clone();
                                // group key from right column
                                let mut right = keys[0].clone();

                                let (left, right) = if left.dtype().is_nested() {
                                    (
                                        ChunkedArray::<BinaryOffsetType>::with_chunk(
                                            "".into(),
                                            row_encode::_get_rows_encoded_unordered(&[
                                                left.clone()
                                            ])?
                                            .into_array(),
                                        )
                                        .into_series(),
                                        ChunkedArray::<BinaryOffsetType>::with_chunk(
                                            "".into(),
                                            row_encode::_get_rows_encoded_unordered(&[
                                                right.clone()
                                            ])?
                                            .into_array(),
                                        )
                                        .into_series(),
                                    )
                                } else {
                                    (
                                        left.into_materialized_series().clone(),
                                        right.into_materialized_series().clone(),
                                    )
                                };

                                PolarsResult::Ok(Arc::new(
                                    left.hash_join_left(&right, JoinValidation::ManyToMany, true)
                                        .unwrap()
                                        .1,
                                ))
                            } else {
                                let df_right =
                                    unsafe { DataFrame::new_no_checks_height_from_first(keys) };
                                let df_left = unsafe {
                                    DataFrame::new_no_checks_height_from_first(group_by_columns)
                                };
                                Ok(Arc::new(
                                    private_left_join_multiple_keys(&df_left, &df_right, true)?.1,
                                ))
                            }
                        };

                        // try to get cached join_tuples
                        let join_opt_ids = if state.cache_window() {
                            if let Some(jt) = state.window_cache.get_join(&cache_key) {
                                jt
                            } else {
                                let jt = get_join_tuples()?;
                                state.window_cache.insert_join(cache_key, jt.clone());
                                jt
                            }
                        } else {
                            get_join_tuples()?
                        };

                        let out = materialize_column(&join_opt_ids, &out_column);
                        Ok(out.into_column())
                    },
                }
            },
        }
    }

    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
        self.function.to_field(input_schema)
    }

    fn is_scalar(&self) -> bool {
        false
    }

    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        _df: &DataFrame,
        _groups: &'a GroupPositions,
        _state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        polars_bail!(InvalidOperation: "window expression not allowed in aggregation");
    }

    fn as_expression(&self) -> Option<&Expr> {
        Some(&self.expr)
    }
}

fn materialize_column(join_opt_ids: &ChunkJoinOptIds, out_column: &Column) -> Column {
    {
        use arrow::Either;
        use polars_ops::chunked_array::TakeChunked;

        match join_opt_ids {
            Either::Left(ids) => unsafe {
                IdxCa::with_nullable_idx(ids, |idx| out_column.take_unchecked(idx))
            },
            Either::Right(ids) => unsafe { out_column.take_opt_chunked_unchecked(ids, false) },
        }
    }
}

/// Simple reducing aggregation can be set by the groups
fn set_by_groups(
    s: &Column,
    ac: &AggregationContext,
    len: usize,
    update_groups: bool,
) -> Option<Column> {
    if update_groups || !ac.original_len {
        return None;
    }
    if s.dtype().to_physical().is_primitive_numeric() {
        let dtype = s.dtype();
        let s = s.to_physical_repr();

        macro_rules! dispatch {
            ($ca:expr) => {{ Some(set_numeric($ca, &ac.groups, len)) }};
        }
        downcast_as_macro_arg_physical!(&s, dispatch)
            .map(|s| unsafe { s.from_physical_unchecked(dtype) }.unwrap())
            .map(Column::from)
    } else {
        None
    }
}
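
// Toy trace of the scatter in `set_numeric` below (illustrative): for aggregated
// values [10, 20] with groups `[[0, 2], [1, 3]]` over `len == 4`, each rayon
// task writes its group's value at that group's row indices, producing
// [10, 20, 10, 20]. The groups partition `0..len`, which is what makes the
// later `values.set_len(len)` sound.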

fn set_numeric<T: PolarsNumericType>(
    ca: &ChunkedArray<T>,
    groups: &GroupsType,
    len: usize,
) -> Series {
    let mut values = Vec::with_capacity(len);
    let ptr: *mut T::Native = values.as_mut_ptr();
    // SAFETY:
    // we will write from different threads but we will never alias.
    let sync_ptr_values = unsafe { SyncPtr::new(ptr) };

    if ca.null_count() == 0 {
        let ca = ca.rechunk();
        match groups {
            GroupsType::Idx(groups) => {
                let agg_vals = ca.cont_slice().expect("rechunked");
                POOL.install(|| {
                    agg_vals
                        .par_iter()
                        .zip(groups.all().par_iter())
                        .for_each(|(v, g)| {
                            let ptr = sync_ptr_values.get();
                            for idx in g.as_slice() {
                                debug_assert!((*idx as usize) < len);
                                unsafe { *ptr.add(*idx as usize) = *v }
                            }
                        })
                })
            },
            GroupsType::Slice { groups, .. } => {
                let agg_vals = ca.cont_slice().expect("rechunked");
                POOL.install(|| {
                    agg_vals
                        .par_iter()
                        .zip(groups.par_iter())
                        .for_each(|(v, [start, g_len])| {
                            let ptr = sync_ptr_values.get();
                            let start = *start as usize;
                            let end = start + *g_len as usize;
                            for idx in start..end {
                                debug_assert!(idx < len);
                                unsafe { *ptr.add(idx) = *v }
                            }
                        })
                });
            },
        }

        // SAFETY: we have written all slots
        unsafe { values.set_len(len) }
        ChunkedArray::<T>::new_vec(ca.name().clone(), values).into_series()
    } else {
        // We don't use a mutable bitmap as the bits would have race conditions!
        // A single byte might alias if we write from multiple threads.
        let mut validity: Vec<bool> = vec![false; len];
        let validity_ptr = validity.as_mut_ptr();
        let sync_ptr_validity = unsafe { SyncPtr::new(validity_ptr) };

        let n_threads = POOL.current_num_threads();
        let offsets = _split_offsets(ca.len(), n_threads);

        match groups {
            GroupsType::Idx(groups) => offsets.par_iter().for_each(|(offset, offset_len)| {
                let offset = *offset;
                let offset_len = *offset_len;
                let ca = ca.slice(offset as i64, offset_len);
                let groups = &groups.all()[offset..offset + offset_len];
                let values_ptr = sync_ptr_values.get();
                let validity_ptr = sync_ptr_validity.get();

                ca.iter().zip(groups.iter()).for_each(|(opt_v, g)| {
                    for idx in g.as_slice() {
                        let idx = *idx as usize;
                        debug_assert!(idx < len);
                        unsafe {
                            match opt_v {
                                Some(v) => {
                                    *values_ptr.add(idx) = v;
                                    *validity_ptr.add(idx) = true;
                                },
                                None => {
                                    *values_ptr.add(idx) = T::Native::default();
                                    *validity_ptr.add(idx) = false;
                                },
                            };
                        }
                    }
                })
            }),
            GroupsType::Slice { groups, .. } => {
                offsets.par_iter().for_each(|(offset, offset_len)| {
                    let offset = *offset;
                    let offset_len = *offset_len;
                    let ca = ca.slice(offset as i64, offset_len);
                    let groups = &groups[offset..offset + offset_len];
                    let values_ptr = sync_ptr_values.get();
                    let validity_ptr = sync_ptr_validity.get();

                    for (opt_v, [start, g_len]) in ca.iter().zip(groups.iter()) {
                        let start = *start as usize;
                        let end = start + *g_len as usize;
                        for idx in start..end {
                            debug_assert!(idx < len);
                            unsafe {
                                match opt_v {
                                    Some(v) => {
                                        *values_ptr.add(idx) = v;
                                        *validity_ptr.add(idx) = true;
                                    },
                                    None => {
                                        *values_ptr.add(idx) = T::Native::default();
                                        *validity_ptr.add(idx) = false;
                                    },
                                };
                            }
                        }
                    }
                })
            },
        }
        // SAFETY: we have written all slots
        unsafe { values.set_len(len) }
        let validity = Bitmap::from(validity);
        let arr = PrimitiveArray::new(
            T::get_static_dtype()
                .to_physical()
                .to_arrow(CompatLevel::newest()),
            values.into(),
            Some(validity),
        );
        Series::try_from((ca.name().clone(), arr.boxed())).unwrap()
    }
}
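
// A minimal sanity sketch for the dense scatter path above. This test is not
// part of the original file; it assumes `GroupsType::Slice`'s fields are
// `{ groups, rolling }` and exercises only the `null_count == 0` branch.
#[cfg(test)]
mod window_scatter_sketch {
    use super::*;

    #[test]
    fn set_numeric_scatters_group_values() {
        // two groups of two consecutive rows each: rows 0..2 and rows 2..4
        let ca = Int32Chunked::new_vec("x".into(), vec![10, 20]);
        let groups = GroupsType::Slice {
            groups: vec![[0, 2], [2, 2]],
            rolling: false,
        };
        let out = set_numeric(&ca, &groups, 4);
        // every row receives the reduced value of its group
        assert_eq!(
            out.i32().unwrap().into_no_null_iter().collect::<Vec<_>>(),
            [10, 10, 20, 20]
        );
    }
}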