use std::cmp::Ordering;
use std::mem::MaybeUninit;
use num_traits::FromPrimitive;
use rayon::ThreadPool;
use rayon::prelude::*;
use crate::IdxSize;
use crate::total_ord::TotalOrd;
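/// Scatters `idx` into `out`: for every `(value, destination)` pair, `value` is
/// written at position `destination` of `out`.
///
/// # Safety
/// The destinations in `idx` must be a permutation of `0..idx.len()`; otherwise
/// the parallel writes may race and `out` may expose uninitialized memory.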
#[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
pub unsafe fn perfect_sort(pool: &ThreadPool, idx: &[(IdxSize, IdxSize)], out: &mut Vec<IdxSize>) {
    // Give each thread a reasonably large chunk so scheduling overhead stays
    // small relative to the work per chunk.
    let chunk_size = std::cmp::max(
        idx.len() / pool.current_num_threads(),
        pool.current_num_threads(),
    );
    out.reserve(idx.len());
    // Smuggle the destination pointer into the parallel closure as a `usize`,
    // as raw pointers are not `Send`/`Sync`.
    let ptr = out.as_mut_ptr() as usize;
    pool.install(|| {
        idx.par_chunks(chunk_size).for_each(|indices| {
            let ptr = ptr as *mut IdxSize;
            for (idx_val, idx_location) in indices {
                // SAFETY: the caller guarantees the destinations are unique and
                // in bounds, so no two threads write to the same slot.
                unsafe { *ptr.add(*idx_location as usize) = *idx_val };
            }
        });
    });
    // SAFETY: every one of the `idx.len()` reserved slots was written above.
    unsafe { out.set_len(idx.len()) };
}
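/// WASM (non-Emscripten) counterpart of `perfect_sort` above, driven by the
/// crate's `wasm::Pool` instead of a rayon `ThreadPool`. The same safety
/// contract applies.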
#[cfg(all(not(target_os = "emscripten"), target_family = "wasm"))]
pub unsafe fn perfect_sort(
pool: &crate::wasm::Pool,
idx: &[(IdxSize, IdxSize)],
out: &mut Vec<IdxSize>,
) {
    // Same scatter as the non-WASM version above.
    let chunk_size = std::cmp::max(
        idx.len() / pool.current_num_threads(),
        pool.current_num_threads(),
    );
    out.reserve(idx.len());
    // Smuggle the destination pointer into the parallel closure as a `usize`.
    let ptr = out.as_mut_ptr() as usize;
    pool.install(|| {
        idx.par_chunks(chunk_size).for_each(|indices| {
            let ptr = ptr as *mut IdxSize;
            for (idx_val, idx_location) in indices {
                // SAFETY: unique, in-bounds destinations (caller contract).
                unsafe { *ptr.add(*idx_location as usize) = *idx_val };
            }
        });
    });
    // SAFETY: every one of the `idx.len()` reserved slots was written above.
    unsafe { out.set_len(idx.len()) };
}
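/// # Safety
/// Every element of `slice` must have been initialized before calling this.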
unsafe fn assume_init_mut<T>(slice: &mut [MaybeUninit<T>]) -> &mut [T] {
unsafe { &mut *(slice as *mut [MaybeUninit<T>] as *mut [T]) }
}
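/// Argsort of `v` in ascending (total) order: returns, for each rank, the index
/// the element had in the input. `scratch` provides the backing storage for the
/// returned slice, so the indices remain valid as long as `scratch` is untouched.
///
/// `v` must yield exactly `n` items, and every index up to `n` must be
/// representable in `Idx`.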
pub fn arg_sort_ascending<'a, T: TotalOrd + Copy + 'a, Idx, I: IntoIterator<Item = T>>(
v: I,
scratch: &'a mut Vec<u8>,
n: usize,
) -> &'a mut [Idx]
where
Idx: FromPrimitive + Copy,
{
    debug_assert_eq!(align_of::<T>(), align_of::<(T, Idx)>());
    let size = size_of::<(T, Idx)>();
    // One extra element's worth of bytes absorbs the alignment slack dropped by
    // `align_to_mut` below.
    let upper_bound = size * n + size;
    scratch.reserve(upper_bound);
    let scratch_slice = unsafe {
        // Reinterpret the spare capacity of the byte buffer as uninitialized
        // `(value, index)` pairs.
        let cap_slice = scratch.spare_capacity_mut();
        let (_, scratch_slice, _) = cap_slice.align_to_mut::<MaybeUninit<(T, Idx)>>();
        &mut scratch_slice[..n]
    };
    // Fill the buffer with `(value, original_index)` pairs and check (in debug
    // builds) that the iterator produced exactly `n` of them; fewer would leave
    // uninitialized pairs behind.
    let mut count = 0;
    for ((i, v), dst) in v.into_iter().enumerate().zip(scratch_slice.iter_mut()) {
        *dst = MaybeUninit::new((v, Idx::from_usize(i).unwrap()));
        count += 1;
    }
    debug_assert_eq!(n, count);
    let scratch_slice = unsafe { assume_init_mut(scratch_slice) };
    // Sort the pairs by value under the total ordering.
    scratch_slice.sort_by(|key1, key2| key1.0.tot_cmp(&key2.0));
    // Compact the indices to the front of the same buffer. The forward copy is
    // sound because `Idx` is no larger than `(T, Idx)`, so each write lands at
    // or before the position it was read from.
    unsafe {
        let src = scratch_slice.as_ptr();
        let (_, scratch_slice_aligned_to_idx, _) = scratch_slice.align_to_mut::<Idx>();
        let dst = scratch_slice_aligned_to_idx.as_mut_ptr();
        for i in 0..n {
            dst.add(i).write((*src.add(i)).1);
        }
        &mut scratch_slice_aligned_to_idx[..n]
    }
}
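/// `Option<T>` wrapper whose ordering is fixed at the type level: `DESCENDING`
/// reverses the comparison of the inner values, and `NULLS_LAST` decides whether
/// `None` sorts after (`true`) or before (`false`) every `Some` value.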
#[derive(PartialEq, Eq, Clone, Hash)]
#[repr(transparent)]
pub struct ReorderWithNulls<T, const DESCENDING: bool, const NULLS_LAST: bool>(pub Option<T>);
impl<T: PartialOrd, const DESCENDING: bool, const NULLS_LAST: bool> PartialOrd
for ReorderWithNulls<T, DESCENDING, NULLS_LAST>
{
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
match (&self.0, &other.0) {
(None, None) => Some(Ordering::Equal),
(None, Some(_)) => {
if NULLS_LAST {
Some(Ordering::Greater)
} else {
Some(Ordering::Less)
}
},
(Some(_), None) => {
if NULLS_LAST {
Some(Ordering::Less)
} else {
Some(Ordering::Greater)
}
},
(Some(l), Some(r)) => {
if DESCENDING {
r.partial_cmp(l)
} else {
l.partial_cmp(r)
}
},
}
}
}
impl<T: Ord, const DESCENDING: bool, const NULLS_LAST: bool> Ord
for ReorderWithNulls<T, DESCENDING, NULLS_LAST>
{
fn cmp(&self, other: &Self) -> Ordering {
match (&self.0, &other.0) {
(None, None) => Ordering::Equal,
(None, Some(_)) => {
if NULLS_LAST {
Ordering::Greater
} else {
Ordering::Less
}
},
(Some(_), None) => {
if NULLS_LAST {
Ordering::Less
} else {
Ordering::Greater
}
},
(Some(l), Some(r)) => {
if DESCENDING {
r.cmp(l)
} else {
l.cmp(r)
}
},
}
}
}
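
// Minimal usage sketches, written as tests. They assume `IdxSize` is an unsigned
// integer alias and that `TotalOrd` is implemented for `u32` in this crate; adjust
// the concrete types if that does not hold.
#[cfg(test)]
mod tests {
    use std::cmp::Ordering;

    use super::*;

    // `perfect_sort` scatters each `(value, destination)` pair into `out`.
    #[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
    #[test]
    fn perfect_sort_scatters_values() {
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(2)
            .build()
            .unwrap();
        let idx: Vec<(IdxSize, IdxSize)> = vec![(0, 2), (1, 0), (2, 1)];
        let mut out = Vec::new();
        // SAFETY: the destinations 0, 1, 2 are unique and in bounds.
        unsafe { perfect_sort(&pool, &idx, &mut out) };
        let expected: Vec<IdxSize> = vec![1, 2, 0];
        assert_eq!(out, expected);
    }

    // `arg_sort_ascending` returns the original positions in ascending value order.
    #[test]
    fn arg_sort_ascending_small() {
        let values: Vec<u32> = vec![3, 1, 2];
        let mut scratch: Vec<u8> = Vec::new();
        let idxs: &mut [IdxSize] =
            arg_sort_ascending(values.iter().copied(), &mut scratch, values.len());
        assert_eq!(idxs.to_vec(), vec![1 as IdxSize, 2, 0]);
    }

    // `ReorderWithNulls` reverses the inner ordering and places nulls as requested.
    #[test]
    fn reorder_with_nulls_ordering() {
        // Ascending, nulls first: `None` sorts before `Some`.
        let none = ReorderWithNulls::<i32, false, false>(None);
        let one = ReorderWithNulls::<i32, false, false>(Some(1));
        assert_eq!(none.cmp(&one), Ordering::Less);

        // Descending, nulls last: inner values compare reversed and `None` sorts last.
        let one = ReorderWithNulls::<i32, true, true>(Some(1));
        let two = ReorderWithNulls::<i32, true, true>(Some(2));
        let none = ReorderWithNulls::<i32, true, true>(None);
        assert_eq!(one.cmp(&two), Ordering::Greater);
        assert_eq!(one.cmp(&none), Ordering::Less);
    }
}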