Path: blob/main/crates/polars-ops/src/chunked_array/top_k.rs
6939 views
use arrow::array::{BinaryViewArray, BooleanArray, PrimitiveArray, StaticArray, View};1use arrow::bitmap::{Bitmap, BitmapBuilder};2use polars_core::chunked_array::ops::sort::arg_bottom_k::_arg_bottom_k;3use polars_core::prelude::*;4use polars_core::series::IsSorted;5use polars_core::{POOL, downcast_as_macro_arg_physical};6use polars_utils::total_ord::TotalOrd;78fn first_n_valid_mask(num_valid: usize, out_len: usize) -> Option<Bitmap> {9if num_valid < out_len {10let mut bm = BitmapBuilder::with_capacity(out_len);11bm.extend_constant(num_valid, true);12bm.extend_constant(out_len - num_valid, false);13Some(bm.freeze())14} else {15None16}17}1819fn top_k_bool_impl(20ca: &ChunkedArray<BooleanType>,21k: usize,22descending: bool,23) -> ChunkedArray<BooleanType> {24if k >= ca.len() && ca.null_count() == 0 {25return ca.clone();26}2728let null_count = ca.null_count();29let non_null_count = ca.len() - ca.null_count();30let true_count = ca.sum().unwrap() as usize;31let false_count = non_null_count - true_count;32let mut out_len = k.min(ca.len());33let validity = first_n_valid_mask(non_null_count, out_len);3435// Logical sequence of physical bits.36let sequence = if descending {37[38(false_count, false),39(true_count, true),40(null_count, false),41]42} else {43[44(true_count, true),45(false_count, false),46(null_count, false),47]48};4950let mut bm = BitmapBuilder::with_capacity(out_len);51for (n, value) in sequence {52if out_len == 0 {53break;54}55let extra = out_len.min(n);56bm.extend_constant(extra, value);57out_len -= extra;58}5960let arr = BooleanArray::from_data_default(bm.freeze(), validity);61ChunkedArray::with_chunk_like(ca, arr)62}6364fn top_k_num_impl<T>(ca: &ChunkedArray<T>, k: usize, descending: bool) -> ChunkedArray<T>65where66T: PolarsNumericType,67{68if k >= ca.len() && ca.null_count() == 0 {69return ca.clone();70}7172// Get rid of all the nulls and transform into Vec<T::Native>.73let mut nnca = ca.drop_nulls();74nnca.rechunk_mut();75let chunk = nnca.downcast_into_iter().next().unwrap();76let (_, buffer, _) = chunk.into_inner();77let mut vec = buffer.make_mut();7879// Partition.80if k < vec.len() {81if descending {82vec.select_nth_unstable_by(k, TotalOrd::tot_cmp);83} else {84vec.select_nth_unstable_by(k, |a, b| TotalOrd::tot_cmp(b, a));85}86}8788// Reconstruct output (with nulls at the end).89let out_len = k.min(ca.len());90let non_null_count = ca.len() - ca.null_count();91vec.resize(out_len, T::Native::default());92let validity = first_n_valid_mask(non_null_count, out_len);9394let arr = PrimitiveArray::from_vec(vec).with_validity_typed(validity);95ChunkedArray::with_chunk_like(ca, arr)96}9798fn top_k_binary_impl(99ca: &ChunkedArray<BinaryType>,100k: usize,101descending: bool,102) -> ChunkedArray<BinaryType> {103if k >= ca.len() && ca.null_count() == 0 {104return ca.clone();105}106107// Get rid of all the nulls and transform into mutable views.108let mut nnca = ca.drop_nulls();109nnca.rechunk_mut();110let chunk = nnca.downcast_into_iter().next().unwrap();111let buffers = chunk.data_buffers().clone();112let mut views = chunk.into_views();113114// Partition.115if k < views.len() {116if descending {117views.select_nth_unstable_by(k, |a, b| unsafe {118let a_sl = a.get_slice_unchecked(&buffers);119let b_sl = b.get_slice_unchecked(&buffers);120a_sl.cmp(b_sl)121});122} else {123views.select_nth_unstable_by(k, |a, b| unsafe {124let a_sl = a.get_slice_unchecked(&buffers);125let b_sl = b.get_slice_unchecked(&buffers);126b_sl.cmp(a_sl)127});128}129}130131// Reconstruct output (with nulls at the end).132let out_len = k.min(ca.len());133let non_null_count = ca.len() - ca.null_count();134views.resize(out_len, View::default());135let validity = first_n_valid_mask(non_null_count, out_len);136137let arr = unsafe {138BinaryViewArray::new_unchecked_unknown_md(139ArrowDataType::BinaryView,140views.into(),141buffers,142validity,143None,144)145};146ChunkedArray::with_chunk_like(ca, arr)147}148149pub fn top_k(s: &[Column], descending: bool) -> PolarsResult<Column> {150fn extract_target_and_k(s: &[Column]) -> PolarsResult<(usize, &Column)> {151let k_s = &s[1];152polars_ensure!(153k_s.len() == 1,154ComputeError: "`k` must be a single value for `top_k`."155);156157let Some(k) = k_s.cast(&IDX_DTYPE)?.idx()?.get(0) else {158polars_bail!(ComputeError: "`k` must be set for `top_k`")159};160161let src = &s[0];162Ok((k as usize, src))163}164165let (k, src) = extract_target_and_k(s)?;166167if src.is_empty() {168return Ok(src.clone());169}170171let sorted_flag = src.is_sorted_flag();172let is_sorted = match src.is_sorted_flag() {173IsSorted::Ascending => true,174IsSorted::Descending => true,175IsSorted::Not => false,176};177if is_sorted {178let out_len = k.min(src.len());179let ignored_len = src.len() - out_len;180let slice_at_start = (sorted_flag == IsSorted::Ascending) == descending;181let nulls_at_start = src.get(0).unwrap() == AnyValue::Null;182let offset = if nulls_at_start == slice_at_start {183src.null_count().min(ignored_len)184} else {1850186};187188return if slice_at_start {189Ok(src.slice(offset as i64, out_len))190} else {191Ok(src.slice(-(offset as i64) - (out_len as i64), out_len))192};193}194195let origin_dtype = src.dtype();196197let s = src.to_physical_repr();198199match s.dtype() {200DataType::Boolean => Ok(top_k_bool_impl(s.bool().unwrap(), k, descending).into_column()),201DataType::String => {202let ca = top_k_binary_impl(&s.str().unwrap().as_binary(), k, descending);203let ca = unsafe { ca.to_string_unchecked() };204Ok(ca.into_column())205},206DataType::Binary => Ok(top_k_binary_impl(s.binary().unwrap(), k, descending).into_column()),207DataType::Null => Ok(src.slice(0, k)),208dt if dt.is_primitive_numeric() => {209macro_rules! dispatch {210($ca:expr) => {{ top_k_num_impl($ca, k, descending).into_column() }};211}212unsafe {213downcast_as_macro_arg_physical!(&s, dispatch).from_physical_unchecked(origin_dtype)214}215},216_ => {217// Fallback to more generic impl.218top_k_by_impl(k, src, std::slice::from_ref(src), vec![descending])219},220}221}222223pub fn top_k_by(s: &[Column], descending: Vec<bool>) -> PolarsResult<Column> {224/// Return (k, src, by)225fn extract_parameters(s: &[Column]) -> PolarsResult<(usize, &Column, &[Column])> {226let k_s = &s[1];227228polars_ensure!(229k_s.len() == 1,230ComputeError: "`k` must be a single value for `top_k`."231);232233let Some(k) = k_s.cast(&IDX_DTYPE)?.idx()?.get(0) else {234polars_bail!(ComputeError: "`k` must be set for `top_k`")235};236237let src = &s[0];238239let by = &s[2..];240241Ok((k as usize, src, by))242}243244let (k, src, by) = extract_parameters(s)?;245246if src.is_empty() {247return Ok(src.clone());248}249250if by.first().map(|x| x.is_empty()).unwrap_or(false) {251return Ok(src.clone());252}253254for s in by {255if s.len() != src.len() {256polars_bail!(ComputeError: "`by` column's ({}) length ({}) should have the same length as the source column length ({}) in `top_k`", s.name(), s.len(), src.len())257}258}259260top_k_by_impl(k, src, by, descending)261}262263fn top_k_by_impl(264k: usize,265src: &Column,266by: &[Column],267descending: Vec<bool>,268) -> PolarsResult<Column> {269if src.is_empty() {270return Ok(src.clone());271}272273let multithreaded = k >= 10000 && POOL.current_num_threads() > 1;274let mut sort_options = SortMultipleOptions {275descending: descending.into_iter().map(|x| !x).collect(),276nulls_last: vec![true; by.len()],277multithreaded,278maintain_order: false,279limit: None,280};281282let idx = _arg_bottom_k(k, by, &mut sort_options)?;283284let result = unsafe {285src.as_materialized_series()286.take_unchecked(&idx.into_inner())287};288Ok(result.into())289}290291292