Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-ops/src/series/ops/rank.rs
6939 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use arrow::array::BooleanArray;
3
use arrow::compute::concatenate::concatenate_validities;
4
use polars_core::prelude::*;
5
use rand::prelude::*;
6
#[cfg(feature = "serde")]
7
use serde::{Deserialize, Serialize};
8
9
use crate::prelude::SeriesSealed;
10
11
// Strategy used to assign ranks to values that compare equal (ties).
//
// Plain `//` comments are used on purpose: schemars turns `///` doc comments
// into schema descriptions, which would alter the `dsl-schema` output.
#[derive(Clone, Copy, Debug, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum RankMethod {
    // Tied values share the mean of the ranks they span (Float64 output).
    Average,
    // Tied values share the lowest rank of their group.
    Min,
    // Tied values share the highest rank of their group.
    Max,
    // Like `Min`, but the next group's rank follows without gaps.
    Dense,
    // Every value gets a distinct rank, following the sort order.
    Ordinal,
    // Every value gets a distinct rank; tied values are ordered randomly.
    #[cfg(feature = "random")]
    Random,
}
23
24
// We might want to add a `nulls_last` or `null_behavior` field.
25
#[derive(Copy, Clone, Debug, PartialEq, Hash)]
26
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
27
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
28
pub struct RankOptions {
29
pub method: RankMethod,
30
pub descending: bool,
31
}
32
33
impl Default for RankOptions {
34
fn default() -> Self {
35
Self {
36
method: RankMethod::Dense,
37
descending: false,
38
}
39
}
40
}
41
42
// Produces a seed from an OS-seeded `SmallRng`; used by `RankMethod::Random`
// when the caller does not pass an explicit seed.
#[cfg(feature = "random")]
fn get_random_seed() -> u64 {
    SmallRng::from_os_rng().next_u64()
}
48
49
unsafe fn rank_impl<F: FnMut(&mut [IdxSize])>(idxs: &IdxCa, neq: &BooleanArray, mut flush_ties: F) {
50
let mut ties_indices = Vec::with_capacity(128);
51
let mut idx_it = idxs.downcast_iter().flat_map(|arr| arr.values_iter());
52
let Some(first_idx) = idx_it.next() else {
53
return;
54
};
55
ties_indices.push(*first_idx);
56
57
for (eq_idx, idx) in idx_it.enumerate() {
58
if neq.value_unchecked(eq_idx) {
59
flush_ties(&mut ties_indices);
60
ties_indices.clear()
61
}
62
63
ties_indices.push(*idx);
64
}
65
flush_ties(&mut ties_indices);
66
}
67
68
fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option<u64>) -> Series {
69
let len = s.len();
70
let null_count = s.null_count();
71
72
if null_count == len {
73
let dt = match method {
74
Average => DataType::Float64,
75
_ => IDX_DTYPE,
76
};
77
return Series::full_null(s.name().clone(), s.len(), &dt);
78
}
79
80
match len {
81
1 => {
82
return match method {
83
Average => Series::new(s.name().clone(), &[1.0f64]),
84
_ => Series::new(s.name().clone(), &[1 as IdxSize]),
85
};
86
},
87
0 => {
88
return match method {
89
Average => Float64Chunked::from_slice(s.name().clone(), &[]).into_series(),
90
_ => IdxCa::from_slice(s.name().clone(), &[]).into_series(),
91
};
92
},
93
_ => {},
94
}
95
96
if null_count == len {
97
return match method {
98
Average => Float64Chunked::full_null(s.name().clone(), len).into_series(),
99
_ => IdxCa::full_null(s.name().clone(), len).into_series(),
100
};
101
}
102
103
let sort_idx_ca = s
104
.arg_sort(SortOptions {
105
descending,
106
nulls_last: true,
107
..Default::default()
108
})
109
.slice(0, len - null_count);
110
111
let validity = concatenate_validities(s.chunks());
112
113
use RankMethod::*;
114
if let Ordinal = method {
115
let mut out = vec![0 as IdxSize; s.len()];
116
let mut rank = 0;
117
for arr in sort_idx_ca.downcast_iter() {
118
for i in arr.values_iter() {
119
out[*i as usize] = rank + 1;
120
rank += 1;
121
}
122
}
123
IdxCa::from_vec_validity(s.name().clone(), out, validity).into_series()
124
} else {
125
let sorted_values = unsafe { s.take_unchecked(&sort_idx_ca) };
126
let not_consecutive_same = sorted_values
127
.slice(1, sorted_values.len() - 1)
128
.not_equal(&sorted_values.slice(0, sorted_values.len() - 1))
129
.unwrap();
130
let neq = not_consecutive_same.rechunk();
131
let neq = neq.downcast_as_array();
132
133
let mut rank = 1;
134
match method {
135
#[cfg(feature = "random")]
136
Random => unsafe {
137
let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_random_seed));
138
let mut out = vec![0 as IdxSize; s.len()];
139
rank_impl(&sort_idx_ca, neq, |ties| {
140
ties.shuffle(&mut rng);
141
for i in ties {
142
*out.get_unchecked_mut(*i as usize) = rank;
143
rank += 1;
144
}
145
});
146
IdxCa::from_vec_validity(s.name().clone(), out, validity).into_series()
147
},
148
Average => unsafe {
149
let mut out = vec![0.0; s.len()];
150
rank_impl(&sort_idx_ca, neq, |ties| {
151
let first = rank;
152
rank += ties.len() as IdxSize;
153
let last = rank - 1;
154
let avg = 0.5 * (first as f64 + last as f64);
155
for i in ties {
156
*out.get_unchecked_mut(*i as usize) = avg;
157
}
158
});
159
Float64Chunked::from_vec_validity(s.name().clone(), out, validity).into_series()
160
},
161
Min => unsafe {
162
let mut out = vec![0 as IdxSize; s.len()];
163
rank_impl(&sort_idx_ca, neq, |ties| {
164
for i in ties.iter() {
165
*out.get_unchecked_mut(*i as usize) = rank;
166
}
167
rank += ties.len() as IdxSize;
168
});
169
IdxCa::from_vec_validity(s.name().clone(), out, validity).into_series()
170
},
171
Max => unsafe {
172
let mut out = vec![0 as IdxSize; s.len()];
173
rank_impl(&sort_idx_ca, neq, |ties| {
174
rank += ties.len() as IdxSize;
175
for i in ties {
176
*out.get_unchecked_mut(*i as usize) = rank - 1;
177
}
178
});
179
IdxCa::from_vec_validity(s.name().clone(), out, validity).into_series()
180
},
181
Dense => unsafe {
182
let mut out = vec![0 as IdxSize; s.len()];
183
rank_impl(&sort_idx_ca, neq, |ties| {
184
for i in ties {
185
*out.get_unchecked_mut(*i as usize) = rank;
186
}
187
rank += 1;
188
});
189
IdxCa::from_vec_validity(s.name().clone(), out, validity).into_series()
190
},
191
Ordinal => unreachable!(),
192
}
193
}
194
}
195
196
pub trait SeriesRank: SeriesSealed {
197
fn rank(&self, options: RankOptions, seed: Option<u64>) -> Series {
198
rank(self.as_series(), options.method, options.descending, seed)
199
}
200
}
201
202
impl SeriesRank for Series {}
203
204
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_rank() -> PolarsResult<()> {
        let s = Series::new("a".into(), &[1, 2, 3, 2, 2, 3, 0]);

        let out = rank(&s, RankMethod::Ordinal, false, None)
            .idx()?
            .into_no_null_iter()
            .collect::<Vec<_>>();
        assert_eq!(out, &[2 as IdxSize, 3, 6, 4, 5, 7, 1]);

        #[cfg(feature = "random")]
        {
            let out = rank(&s, RankMethod::Random, false, None)
                .idx()?
                .into_no_null_iter()
                .collect::<Vec<_>>();
            assert_eq!(out[0], 2);
            assert_eq!(out[6], 1);
            assert_eq!(out[1] + out[3] + out[4], 12);
            assert_eq!(out[2] + out[5], 13);
            assert_ne!(out[1], out[3]);
            assert_ne!(out[1], out[4]);
            assert_ne!(out[3], out[4]);
        }

        let out = rank(&s, RankMethod::Dense, false, None)
            .idx()?
            .into_no_null_iter()
            .collect::<Vec<_>>();
        assert_eq!(out, &[2, 3, 4, 3, 3, 4, 1]);

        let out = rank(&s, RankMethod::Max, false, None)
            .idx()?
            .into_no_null_iter()
            .collect::<Vec<_>>();
        assert_eq!(out, &[2, 5, 7, 5, 5, 7, 1]);

        let out = rank(&s, RankMethod::Min, false, None)
            .idx()?
            .into_no_null_iter()
            .collect::<Vec<_>>();
        assert_eq!(out, &[2, 3, 6, 3, 3, 6, 1]);

        let out = rank(&s, RankMethod::Average, false, None)
            .f64()?
            .into_no_null_iter()
            .collect::<Vec<_>>();
        assert_eq!(out, &[2.0f64, 4.0, 6.5, 4.0, 4.0, 6.5, 1.0]);

        let s = Series::new(
            "a".into(),
            &[Some(1), Some(2), Some(3), Some(2), None, None, Some(0)],
        );

        let out = rank(&s, RankMethod::Average, false, None)
            .f64()?
            .into_iter()
            .collect::<Vec<_>>();

        assert_eq!(
            out,
            &[
                Some(2.0f64),
                Some(3.5),
                Some(5.0),
                Some(3.5),
                None,
                None,
                Some(1.0)
            ]
        );
        let s = Series::new(
            "a".into(),
            &[
                Some(5),
                Some(6),
                Some(4),
                None,
                Some(78),
                Some(4),
                Some(2),
                Some(8),
            ],
        );
        let out = rank(&s, RankMethod::Max, false, None)
            .idx()?
            .into_iter()
            .collect::<Vec<_>>();
        assert_eq!(
            out,
            &[
                Some(4),
                Some(5),
                Some(3),
                None,
                Some(7),
                Some(3),
                Some(1),
                Some(6)
            ]
        );

        Ok(())
    }

    #[test]
    fn test_rank_all_null() -> PolarsResult<()> {
        let s = UInt32Chunked::new("".into(), &[None, None, None]).into_series();
        let out = rank(&s, RankMethod::Average, false, None)
            .f64()?
            .into_iter()
            .collect::<Vec<_>>();
        assert_eq!(out, &[None, None, None]);
        let out = rank(&s, RankMethod::Dense, false, None)
            .idx()?
            .into_iter()
            .collect::<Vec<_>>();
        assert_eq!(out, &[None, None, None]);
        Ok(())
    }

    #[test]
    fn test_rank_empty() {
        let s = UInt32Chunked::from_slice("".into(), &[]).into_series();
        let out = rank(&s, RankMethod::Average, false, None);
        assert_eq!(out.dtype(), &DataType::Float64);
        let out = rank(&s, RankMethod::Max, false, None);
        assert_eq!(out.dtype(), &IDX_DTYPE);
    }

    #[test]
    fn test_rank_reverse() -> PolarsResult<()> {
        let s = Series::new("".into(), &[None, Some(1), Some(1), Some(5), None]);
        let out = rank(&s, RankMethod::Dense, true, None)
            .idx()?
            .into_iter()
            .collect::<Vec<_>>();
        assert_eq!(out, &[None, Some(2 as IdxSize), Some(2), Some(1), None]);

        Ok(())
    }

    // Exercises the dedicated length-1 fast path.
    #[test]
    fn test_rank_single_value() -> PolarsResult<()> {
        let s = Series::new("a".into(), &[42]);

        let out = rank(&s, RankMethod::Average, false, None)
            .f64()?
            .into_iter()
            .collect::<Vec<_>>();
        assert_eq!(out, &[Some(1.0f64)]);

        let out = rank(&s, RankMethod::Dense, false, None)
            .idx()?
            .into_iter()
            .collect::<Vec<_>>();
        assert_eq!(out, &[Some(1 as IdxSize)]);

        Ok(())
    }

    // Descending order with a tied pair, checked for every deterministic method.
    #[test]
    fn test_rank_descending_ties() -> PolarsResult<()> {
        let s = Series::new("a".into(), &[3, 1, 3, 2]);

        let out = rank(&s, RankMethod::Min, true, None)
            .idx()?
            .into_no_null_iter()
            .collect::<Vec<_>>();
        assert_eq!(out, &[1 as IdxSize, 4, 1, 3]);

        let out = rank(&s, RankMethod::Max, true, None)
            .idx()?
            .into_no_null_iter()
            .collect::<Vec<_>>();
        assert_eq!(out, &[2 as IdxSize, 4, 2, 3]);

        let out = rank(&s, RankMethod::Dense, true, None)
            .idx()?
            .into_no_null_iter()
            .collect::<Vec<_>>();
        assert_eq!(out, &[1 as IdxSize, 3, 1, 2]);

        let out = rank(&s, RankMethod::Average, true, None)
            .f64()?
            .into_no_null_iter()
            .collect::<Vec<_>>();
        assert_eq!(out, &[1.5f64, 4.0, 1.5, 3.0]);

        // Ties keep their order of occurrence.
        let out = rank(&s, RankMethod::Ordinal, true, None)
            .idx()?
            .into_no_null_iter()
            .collect::<Vec<_>>();
        assert_eq!(out, &[1 as IdxSize, 4, 2, 3]);

        Ok(())
    }

    #[test]
    fn test_rank_ordinal_with_nulls() -> PolarsResult<()> {
        let s = Series::new("a".into(), &[Some(2), None, Some(1), Some(2)]);
        let out = rank(&s, RankMethod::Ordinal, false, None)
            .idx()?
            .into_iter()
            .collect::<Vec<_>>();
        assert_eq!(out, &[Some(2 as IdxSize), None, Some(1), Some(3)]);
        Ok(())
    }
}
350
351