Path: blob/main/crates/polars-ops/src/series/ops/rank.rs
6939 views
#![allow(unsafe_op_in_unsafe_fn)]1use arrow::array::BooleanArray;2use arrow::compute::concatenate::concatenate_validities;3use polars_core::prelude::*;4use rand::prelude::*;5#[cfg(feature = "serde")]6use serde::{Deserialize, Serialize};78use crate::prelude::SeriesSealed;910#[derive(Copy, Clone, Debug, PartialEq, Hash)]11#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]12#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]13pub enum RankMethod {14Average,15Min,16Max,17Dense,18Ordinal,19#[cfg(feature = "random")]20Random,21}2223// We might want to add a `nulls_last` or `null_behavior` field.24#[derive(Copy, Clone, Debug, PartialEq, Hash)]25#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]26#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]27pub struct RankOptions {28pub method: RankMethod,29pub descending: bool,30}3132impl Default for RankOptions {33fn default() -> Self {34Self {35method: RankMethod::Dense,36descending: false,37}38}39}4041#[cfg(feature = "random")]42fn get_random_seed() -> u64 {43let mut rng = SmallRng::from_os_rng();4445rng.next_u64()46}4748unsafe fn rank_impl<F: FnMut(&mut [IdxSize])>(idxs: &IdxCa, neq: &BooleanArray, mut flush_ties: F) {49let mut ties_indices = Vec::with_capacity(128);50let mut idx_it = idxs.downcast_iter().flat_map(|arr| arr.values_iter());51let Some(first_idx) = idx_it.next() else {52return;53};54ties_indices.push(*first_idx);5556for (eq_idx, idx) in idx_it.enumerate() {57if neq.value_unchecked(eq_idx) {58flush_ties(&mut ties_indices);59ties_indices.clear()60}6162ties_indices.push(*idx);63}64flush_ties(&mut ties_indices);65}6667fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option<u64>) -> Series {68let len = s.len();69let null_count = s.null_count();7071if null_count == len {72let dt = match method {73Average => DataType::Float64,74_ => IDX_DTYPE,75};76return Series::full_null(s.name().clone(), s.len(), &dt);77}7879match len {801 => {81return match method {82Average => Series::new(s.name().clone(), &[1.0f64]),83_ => Series::new(s.name().clone(), &[1 as IdxSize]),84};85},860 => {87return match method {88Average => Float64Chunked::from_slice(s.name().clone(), &[]).into_series(),89_ => IdxCa::from_slice(s.name().clone(), &[]).into_series(),90};91},92_ => {},93}9495if null_count == len {96return match method {97Average => Float64Chunked::full_null(s.name().clone(), len).into_series(),98_ => IdxCa::full_null(s.name().clone(), len).into_series(),99};100}101102let sort_idx_ca = s103.arg_sort(SortOptions {104descending,105nulls_last: true,106..Default::default()107})108.slice(0, len - null_count);109110let validity = concatenate_validities(s.chunks());111112use RankMethod::*;113if let Ordinal = method {114let mut out = vec![0 as IdxSize; s.len()];115let mut rank = 0;116for arr in sort_idx_ca.downcast_iter() {117for i in arr.values_iter() {118out[*i as usize] = rank + 1;119rank += 1;120}121}122IdxCa::from_vec_validity(s.name().clone(), out, validity).into_series()123} else {124let sorted_values = unsafe { s.take_unchecked(&sort_idx_ca) };125let not_consecutive_same = sorted_values126.slice(1, sorted_values.len() - 1)127.not_equal(&sorted_values.slice(0, sorted_values.len() - 1))128.unwrap();129let neq = not_consecutive_same.rechunk();130let neq = neq.downcast_as_array();131132let mut rank = 1;133match method {134#[cfg(feature = "random")]135Random => unsafe {136let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_random_seed));137let mut out = vec![0 as IdxSize; s.len()];138rank_impl(&sort_idx_ca, neq, |ties| {139ties.shuffle(&mut rng);140for i in ties {141*out.get_unchecked_mut(*i as usize) = rank;142rank += 1;143}144});145IdxCa::from_vec_validity(s.name().clone(), out, validity).into_series()146},147Average => unsafe {148let mut out = vec![0.0; s.len()];149rank_impl(&sort_idx_ca, neq, |ties| {150let first = rank;151rank += ties.len() as IdxSize;152let last = rank - 1;153let avg = 0.5 * (first as f64 + last as f64);154for i in ties {155*out.get_unchecked_mut(*i as usize) = avg;156}157});158Float64Chunked::from_vec_validity(s.name().clone(), out, validity).into_series()159},160Min => unsafe {161let mut out = vec![0 as IdxSize; s.len()];162rank_impl(&sort_idx_ca, neq, |ties| {163for i in ties.iter() {164*out.get_unchecked_mut(*i as usize) = rank;165}166rank += ties.len() as IdxSize;167});168IdxCa::from_vec_validity(s.name().clone(), out, validity).into_series()169},170Max => unsafe {171let mut out = vec![0 as IdxSize; s.len()];172rank_impl(&sort_idx_ca, neq, |ties| {173rank += ties.len() as IdxSize;174for i in ties {175*out.get_unchecked_mut(*i as usize) = rank - 1;176}177});178IdxCa::from_vec_validity(s.name().clone(), out, validity).into_series()179},180Dense => unsafe {181let mut out = vec![0 as IdxSize; s.len()];182rank_impl(&sort_idx_ca, neq, |ties| {183for i in ties {184*out.get_unchecked_mut(*i as usize) = rank;185}186rank += 1;187});188IdxCa::from_vec_validity(s.name().clone(), out, validity).into_series()189},190Ordinal => unreachable!(),191}192}193}194195pub trait SeriesRank: SeriesSealed {196fn rank(&self, options: RankOptions, seed: Option<u64>) -> Series {197rank(self.as_series(), options.method, options.descending, seed)198}199}200201impl SeriesRank for Series {}202203#[cfg(test)]204mod test {205use super::*;206207#[test]208fn test_rank() -> PolarsResult<()> {209let s = Series::new("a".into(), &[1, 2, 3, 2, 2, 3, 0]);210211let out = rank(&s, RankMethod::Ordinal, false, None)212.idx()?213.into_no_null_iter()214.collect::<Vec<_>>();215assert_eq!(out, &[2 as IdxSize, 3, 6, 4, 5, 7, 1]);216217#[cfg(feature = "random")]218{219let out = rank(&s, RankMethod::Random, false, None)220.idx()?221.into_no_null_iter()222.collect::<Vec<_>>();223assert_eq!(out[0], 2);224assert_eq!(out[6], 1);225assert_eq!(out[1] + out[3] + out[4], 12);226assert_eq!(out[2] + out[5], 13);227assert_ne!(out[1], out[3]);228assert_ne!(out[1], out[4]);229assert_ne!(out[3], out[4]);230}231232let out = rank(&s, RankMethod::Dense, false, None)233.idx()?234.into_no_null_iter()235.collect::<Vec<_>>();236assert_eq!(out, &[2, 3, 4, 3, 3, 4, 1]);237238let out = rank(&s, RankMethod::Max, false, None)239.idx()?240.into_no_null_iter()241.collect::<Vec<_>>();242assert_eq!(out, &[2, 5, 7, 5, 5, 7, 1]);243244let out = rank(&s, RankMethod::Min, false, None)245.idx()?246.into_no_null_iter()247.collect::<Vec<_>>();248assert_eq!(out, &[2, 3, 6, 3, 3, 6, 1]);249250let out = rank(&s, RankMethod::Average, false, None)251.f64()?252.into_no_null_iter()253.collect::<Vec<_>>();254assert_eq!(out, &[2.0f64, 4.0, 6.5, 4.0, 4.0, 6.5, 1.0]);255256let s = Series::new(257"a".into(),258&[Some(1), Some(2), Some(3), Some(2), None, None, Some(0)],259);260261let out = rank(&s, RankMethod::Average, false, None)262.f64()?263.into_iter()264.collect::<Vec<_>>();265266assert_eq!(267out,268&[269Some(2.0f64),270Some(3.5),271Some(5.0),272Some(3.5),273None,274None,275Some(1.0)276]277);278let s = Series::new(279"a".into(),280&[281Some(5),282Some(6),283Some(4),284None,285Some(78),286Some(4),287Some(2),288Some(8),289],290);291let out = rank(&s, RankMethod::Max, false, None)292.idx()?293.into_iter()294.collect::<Vec<_>>();295assert_eq!(296out,297&[298Some(4),299Some(5),300Some(3),301None,302Some(7),303Some(3),304Some(1),305Some(6)306]307);308309Ok(())310}311312#[test]313fn test_rank_all_null() -> PolarsResult<()> {314let s = UInt32Chunked::new("".into(), &[None, None, None]).into_series();315let out = rank(&s, RankMethod::Average, false, None)316.f64()?317.into_iter()318.collect::<Vec<_>>();319assert_eq!(out, &[None, None, None]);320let out = rank(&s, RankMethod::Dense, false, None)321.idx()?322.into_iter()323.collect::<Vec<_>>();324assert_eq!(out, &[None, None, None]);325Ok(())326}327328#[test]329fn test_rank_empty() {330let s = UInt32Chunked::from_slice("".into(), &[]).into_series();331let out = rank(&s, RankMethod::Average, false, None);332assert_eq!(out.dtype(), &DataType::Float64);333let out = rank(&s, RankMethod::Max, false, None);334assert_eq!(out.dtype(), &IDX_DTYPE);335}336337#[test]338fn test_rank_reverse() -> PolarsResult<()> {339let s = Series::new("".into(), &[None, Some(1), Some(1), Some(5), None]);340let out = rank(&s, RankMethod::Dense, true, None)341.idx()?342.into_iter()343.collect::<Vec<_>>();344assert_eq!(out, &[None, Some(2 as IdxSize), Some(2), Some(1), None]);345346Ok(())347}348}349350351