GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-utils/src/sort.rs
use std::cmp::Ordering;
use std::mem::MaybeUninit;

use num_traits::FromPrimitive;
use rayon::ThreadPool;
use rayon::prelude::*;

use crate::IdxSize;
use crate::total_ord::TotalOrd;

/// This is a perfect sort, particularly useful for an arg_sort of an arg_sort.
/// The second arg_sort sorts indices from `0` to `len`, so each index can simply be
/// assigned to its new location.
///
/// Besides that, we know that all indices are unique and therefore do not alias,
/// so we can parallelize.
///
/// This sort does not sort in place and will allocate.
///
/// - The right indices are used for sorting.
/// - The left indices are placed at the location the right index points to.
///
/// # Safety
/// The caller must ensure that the right indices of `&[(_, IdxSize)]` are unique integers covering `0..idx.len()`.
#[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
pub unsafe fn perfect_sort(pool: &ThreadPool, idx: &[(IdxSize, IdxSize)], out: &mut Vec<IdxSize>) {
    let chunk_size = std::cmp::max(
        idx.len() / pool.current_num_threads(),
        pool.current_num_threads(),
    );

    out.reserve(idx.len());
    let ptr = out.as_mut_ptr() as *const IdxSize as usize;

    pool.install(|| {
        idx.par_chunks(chunk_size).for_each(|indices| {
            let ptr = ptr as *mut IdxSize;
            for (idx_val, idx_location) in indices {
                // SAFETY:
                // idx_location is in bounds by invariant of this function
                // and we ensured we have at least `idx.len()` capacity
                unsafe { *ptr.add(*idx_location as usize) = *idx_val };
            }
        });
    });
    // SAFETY:
    // all elements are written
    unsafe { out.set_len(idx.len()) };
}
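
// A minimal usage sketch (not part of the original file), assuming a rayon
// `ThreadPool` built with the default builder. Each `(left, right)` pair
// scatters `left` into slot `right` of the output, so the right indices must
// be unique and cover `0..idx.len()`.
#[cfg(all(test, any(target_os = "emscripten", not(target_family = "wasm"))))]
mod perfect_sort_example {
    use super::*;

    #[test]
    fn scatter_example() {
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(2)
            .build()
            .unwrap();
        let idx: Vec<(IdxSize, IdxSize)> = vec![(0, 2), (1, 0), (2, 1)];
        let mut out: Vec<IdxSize> = Vec::new();
        // SAFETY: the right indices are unique and cover 0..idx.len().
        unsafe { perfect_sort(&pool, &idx, &mut out) };
        assert_eq!(out, vec![1, 2, 0]);
    }
}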

// wasm alternative with different signature
#[cfg(all(not(target_os = "emscripten"), target_family = "wasm"))]
pub unsafe fn perfect_sort(
    pool: &crate::wasm::Pool,
    idx: &[(IdxSize, IdxSize)],
    out: &mut Vec<IdxSize>,
) {
    let chunk_size = std::cmp::max(
        idx.len() / pool.current_num_threads(),
        pool.current_num_threads(),
    );

    out.reserve(idx.len());
    let ptr = out.as_mut_ptr() as *const IdxSize as usize;

    pool.install(|| {
        idx.par_chunks(chunk_size).for_each(|indices| {
            let ptr = ptr as *mut IdxSize;
            for (idx_val, idx_location) in indices {
                // SAFETY:
                // idx_location is in bounds by invariant of this function
                // and we ensured we have at least `idx.len()` capacity
                *ptr.add(*idx_location as usize) = *idx_val;
            }
        });
    });
    // SAFETY:
    // all elements are written
    out.set_len(idx.len());
}

/// # Safety
/// The caller must ensure that every element of `slice` has been initialized.
unsafe fn assume_init_mut<T>(slice: &mut [MaybeUninit<T>]) -> &mut [T] {
    unsafe { &mut *(slice as *mut [MaybeUninit<T>] as *mut [T]) }
}

/// Returns the indices that would sort the items yielded by `v` in ascending order.
///
/// `scratch` is used as scratch space for the intermediate `(value, index)` pairs;
/// `n` must not exceed the number of items yielded by `v`.
pub fn arg_sort_ascending<'a, T: TotalOrd + Copy + 'a, Idx, I: IntoIterator<Item = T>>(
    v: I,
    scratch: &'a mut Vec<u8>,
    n: usize,
) -> &'a mut [Idx]
where
    Idx: FromPrimitive + Copy,
{
    // Needed to be able to write back to back in the same buffer.
    debug_assert_eq!(align_of::<T>(), align_of::<(T, Idx)>());
    let size = size_of::<(T, Idx)>();
    let upper_bound = size * n + size;
    scratch.reserve(upper_bound);
    let scratch_slice = unsafe {
        let cap_slice = scratch.spare_capacity_mut();
        let (_, scratch_slice, _) = cap_slice.align_to_mut::<MaybeUninit<(T, Idx)>>();
        &mut scratch_slice[..n]
    };

    for ((i, v), dst) in v.into_iter().enumerate().zip(scratch_slice.iter_mut()) {
        *dst = MaybeUninit::new((v, Idx::from_usize(i).unwrap()));
    }
    debug_assert_eq!(n, scratch_slice.len());

    let scratch_slice = unsafe { assume_init_mut(scratch_slice) };
    scratch_slice.sort_by(|key1, key2| key1.0.tot_cmp(&key2.0));

    // Now we write the indices back into the same buffer,
    // going from `(T, Idx)` pairs to plain `Idx` values.
    unsafe {
        let src = scratch_slice.as_ptr();

        let (_, scratch_slice_aligned_to_idx, _) = scratch_slice.align_to_mut::<Idx>();

        let dst = scratch_slice_aligned_to_idx.as_mut_ptr();

        for i in 0..n {
            dst.add(i).write((*src.add(i)).1);
        }

        &mut scratch_slice_aligned_to_idx[..n]
    }
}
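
// A minimal usage sketch (not part of the original file), assuming `TotalOrd`
// is implemented for `f64` in `crate::total_ord` (as it is for the other
// primitive types) and using `IdxSize` as the index type.
#[cfg(test)]
mod arg_sort_ascending_example {
    use super::*;

    #[test]
    fn sorts_indices_ascending() {
        let values = [3.0f64, 1.0, 2.0];
        let mut scratch: Vec<u8> = Vec::new();
        let idxs: &mut [IdxSize] =
            arg_sort_ascending(values.iter().copied(), &mut scratch, values.len());
        assert_eq!(&idxs[..], &[1, 2, 0][..]);
    }
}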

/// Wraps an `Option<T>` so that sorting respects the `DESCENDING` and
/// `NULLS_LAST` parameters: nulls (`None`) compare last when `NULLS_LAST`
/// is set and first otherwise, while non-null values compare in reversed
/// order when `DESCENDING` is set.
#[derive(PartialEq, Eq, Clone, Hash)]
#[repr(transparent)]
pub struct ReorderWithNulls<T, const DESCENDING: bool, const NULLS_LAST: bool>(pub Option<T>);

impl<T: PartialOrd, const DESCENDING: bool, const NULLS_LAST: bool> PartialOrd
    for ReorderWithNulls<T, DESCENDING, NULLS_LAST>
{
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        match (&self.0, &other.0) {
            (None, None) => Some(Ordering::Equal),
            (None, Some(_)) => {
                if NULLS_LAST {
                    Some(Ordering::Greater)
                } else {
                    Some(Ordering::Less)
                }
            },
            (Some(_), None) => {
                if NULLS_LAST {
                    Some(Ordering::Less)
                } else {
                    Some(Ordering::Greater)
                }
            },
            (Some(l), Some(r)) => {
                if DESCENDING {
                    r.partial_cmp(l)
                } else {
                    l.partial_cmp(r)
                }
            },
        }
    }
}

impl<T: Ord, const DESCENDING: bool, const NULLS_LAST: bool> Ord
    for ReorderWithNulls<T, DESCENDING, NULLS_LAST>
{
    fn cmp(&self, other: &Self) -> Ordering {
        match (&self.0, &other.0) {
            (None, None) => Ordering::Equal,
            (None, Some(_)) => {
                if NULLS_LAST {
                    Ordering::Greater
                } else {
                    Ordering::Less
                }
            },
            (Some(_), None) => {
                if NULLS_LAST {
                    Ordering::Less
                } else {
                    Ordering::Greater
                }
            },
            (Some(l), Some(r)) => {
                if DESCENDING {
                    r.cmp(l)
                } else {
                    l.cmp(r)
                }
            },
        }
    }
}
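
// A minimal usage sketch (not part of the original file): sorting with
// `ReorderWithNulls` as the key. The const parameters are
// `<T, DESCENDING, NULLS_LAST>`, so `<i32, true, true>` gives a descending,
// nulls-last order.
#[cfg(test)]
mod reorder_with_nulls_example {
    use super::*;

    #[test]
    fn descending_nulls_last() {
        let mut vals = vec![Some(1), None, Some(3), Some(2)];
        vals.sort_by_key(|v| ReorderWithNulls::<i32, true, true>(*v));
        assert_eq!(vals, vec![Some(3), Some(2), Some(1), None]);
    }
}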