use std::marker::PhantomData;
use std::mem::ManuallyDrop;
use std::ops::{Deref, DerefMut};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicU64, Ordering};

use bytemuck::Pod;

use crate::ffi::InternalArrowArray;
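
/// Layout and destructor information for the `Vec` that originally backed a
/// buffer, stored type-erased so the allocation can be freed correctly even
/// after the storage has been transmuted to a different element type.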
struct VecVTable {
size: usize,
align: usize,
drop_buffer: unsafe fn(*mut (), usize),
}
impl VecVTable {
const fn new<T>() -> Self {
unsafe fn drop_buffer<T>(ptr: *mut (), cap: usize) {
unsafe { drop(Vec::from_raw_parts(ptr.cast::<T>(), 0, cap)) }
}
Self {
size: size_of::<T>(),
align: align_of::<T>(),
drop_buffer: drop_buffer::<T>,
}
}
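/// Returns a `'static` vtable for `T`, obtained through promotion of an
/// inline `const` block.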
fn new_static<T>() -> &'static Self {
const { &Self::new::<T>() }
}
}
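
/// Describes who owns the underlying buffer and how it must be released.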
enum BackingStorage {
Vec {
original_capacity: usize, // Capacity in elements of the original `T`, not bytes.
vtable: &'static VecVTable,
},
InternalArrowArray(InternalArrowArray), // Kept alive by an imported Arrow array.
External, // Borrowed memory (e.g. a static slice); nothing to free.
Leaked,   // Intentionally never freed.
}
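
/// State shared by all clones of a `SharedStorage`: the reference count, the
/// data pointer and byte length, and how to release the buffer.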
struct SharedStorageInner<T> {
ref_count: AtomicU64,
ptr: *mut T,
length_in_bytes: usize,
backing: BackingStorage,
phantom: PhantomData<T>,
}
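
// SAFETY: shared access only hands out `&T`/`&[T]`, and mutable access is only
// granted when the reference count is 1, so `T: Send + Sync` suffices.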
unsafe impl<T: Sync + Send> Sync for SharedStorageInner<T> {}
impl<T> SharedStorageInner<T> {
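/// Takes ownership of `v`'s buffer without copying; the Vec itself is
/// forgotten and its allocation is released later through the stored vtable.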
pub fn from_vec(mut v: Vec<T>) -> Self {
let length_in_bytes = v.len() * size_of::<T>();
let original_capacity = v.capacity();
let ptr = v.as_mut_ptr();
core::mem::forget(v);
Self {
ref_count: AtomicU64::new(1),
ptr,
length_in_bytes,
backing: BackingStorage::Vec {
original_capacity,
vtable: VecVTable::new_static::<T>(),
},
phantom: PhantomData,
}
}
}
impl<T> Drop for SharedStorageInner<T> {
fn drop(&mut self) {
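// Take the backing by value, leaving a variant behind that frees nothing.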
match core::mem::replace(&mut self.backing, BackingStorage::External) {
BackingStorage::InternalArrowArray(a) => drop(a),
BackingStorage::Vec {
original_capacity,
vtable,
} => unsafe {
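// Run the element destructors (if any) before releasing the allocation.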
if std::mem::needs_drop::<T>() {
core::ptr::drop_in_place(core::ptr::slice_from_raw_parts_mut(
self.ptr,
self.length_in_bytes / size_of::<T>(),
));
}
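// Rebuild the original Vec (with length 0) through the vtable to free the buffer.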
if original_capacity > 0 {
(vtable.drop_buffer)(self.ptr.cast(), original_capacity);
}
},
BackingStorage::External | BackingStorage::Leaked => {},
}
}
}
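
/// A reference-counted buffer of `T` that may be backed by a Rust `Vec`, a
/// static slice, or memory owned by an imported Arrow array.
///
/// A minimal usage sketch (not a doctest, assuming the module is in scope):
///
/// ```ignore
/// let s = SharedStorage::from_vec(vec![1u32, 2, 3]);
/// assert_eq!(&*s, &[1, 2, 3]); // Derefs to a slice.
/// let s2 = s.clone(); // Shares the same allocation.
/// drop(s2);
/// ```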
pub struct SharedStorage<T> {
inner: NonNull<SharedStorageInner<T>>,
phantom: PhantomData<SharedStorageInner<T>>,
}
unsafe impl<T: Sync + Send> Send for SharedStorage<T> {}
unsafe impl<T: Sync + Send> Sync for SharedStorage<T> {}
impl<T> Default for SharedStorage<T> {
fn default() -> Self {
Self::empty()
}
}
impl<T> SharedStorage<T> {
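/// An empty storage pointing at a shared static inner; the `Leaked` backing
/// means it is never refcounted or freed.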
const fn empty() -> Self {
assert!(align_of::<T>() <= 1 << 30);
static INNER: SharedStorageInner<()> = SharedStorageInner {
ref_count: AtomicU64::new(1),
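// Dangling but well-aligned pointer: the assert above guarantees
// `align_of::<T>() <= 1 << 30`.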
ptr: core::ptr::without_provenance_mut(1 << 30),
length_in_bytes: 0,
backing: BackingStorage::Leaked,
phantom: PhantomData,
};
Self {
inner: NonNull::new(&raw const INNER as *mut SharedStorageInner<T>).unwrap(),
phantom: PhantomData,
}
}
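/// Wraps a `'static` slice; the `External` backing means nothing is freed when
/// the last reference is dropped.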
pub fn from_static(slice: &'static [T]) -> Self {
#[expect(clippy::manual_slice_size_calculation)]
let length_in_bytes = slice.len() * size_of::<T>();
let ptr = slice.as_ptr().cast_mut();
let inner = SharedStorageInner {
ref_count: AtomicU64::new(1),
ptr,
length_in_bytes,
backing: BackingStorage::External,
phantom: PhantomData,
};
Self {
inner: NonNull::new(Box::into_raw(Box::new(inner))).unwrap(),
phantom: PhantomData,
}
}
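/// Takes ownership of the Vec's buffer without copying it.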
pub fn from_vec(v: Vec<T>) -> Self {
Self {
inner: NonNull::new(Box::into_raw(Box::new(SharedStorageInner::from_vec(v)))).unwrap(),
phantom: PhantomData,
}
}
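/// Wraps a buffer kept alive by `arr`; `arr` is dropped (releasing the foreign
/// memory) when the last reference goes away. `ptr` and `len` must describe a
/// valid buffer of `T` owned by that array.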
pub fn from_internal_arrow_array(ptr: *const T, len: usize, arr: InternalArrowArray) -> Self {
let inner = SharedStorageInner {
ref_count: AtomicU64::new(1),
ptr: ptr.cast_mut(),
length_in_bytes: len * size_of::<T>(),
backing: BackingStorage::InternalArrowArray(arr),
phantom: PhantomData,
};
Self {
inner: NonNull::new(Box::into_raw(Box::new(inner))).unwrap(),
phantom: PhantomData,
}
}
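/// Ensures the backing memory is never released, even when the last reference
/// is dropped. Panics if this storage is not exclusively owned.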
pub fn leak(&mut self) {
assert!(self.is_exclusive());
unsafe {
let inner = &mut *self.inner.as_ptr();
core::mem::forget(core::mem::replace(
&mut inner.backing,
BackingStorage::Leaked,
));
}
}
}
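
/// A guard that exposes a `SharedStorage<T>` as a mutable `Vec<T>`; when
/// dropped, the (possibly reallocated) Vec is written back into the storage.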
pub struct SharedStorageAsVecMut<'a, T> {
ss: &'a mut SharedStorage<T>,
vec: ManuallyDrop<Vec<T>>,
}
impl<T> Deref for SharedStorageAsVecMut<'_, T> {
type Target = Vec<T>;
fn deref(&self) -> &Self::Target {
&self.vec
}
}
impl<T> DerefMut for SharedStorageAsVecMut<'_, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.vec
}
}
impl<T> Drop for SharedStorageAsVecMut<'_, T> {
fn drop(&mut self) {
unsafe {
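// Restore the storage from the Vec, which may have grown or reallocated.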
let vec = ManuallyDrop::take(&mut self.vec);
let inner = self.ss.inner.as_ptr();
inner.write(SharedStorageInner::from_vec(vec));
}
}
}
impl<T> SharedStorage<T> {
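/// Number of elements of `T` in the storage.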
#[inline(always)]
pub fn len(&self) -> usize {
self.inner().length_in_bytes / size_of::<T>()
}
#[inline(always)]
pub fn as_ptr(&self) -> *const T {
self.inner().ptr
}
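/// Returns whether this is the only handle to the underlying allocation;
/// taking `&mut self` ensures the answer cannot be invalidated through this
/// handle while it is used.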
#[inline(always)]
pub fn is_exclusive(&mut self) -> bool {
self.inner().ref_count.load(Ordering::Acquire) == 1
}
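/// The current reference count of the underlying allocation.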
#[inline(always)]
pub fn refcount(&self) -> u64 {
self.inner().ref_count.load(Ordering::Acquire)
}
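/// Returns a mutable slice over the data if this handle is the only reference.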
pub fn try_as_mut_slice(&mut self) -> Option<&mut [T]> {
self.is_exclusive().then(|| {
let inner = self.inner();
let len = inner.length_in_bytes / size_of::<T>();
unsafe { core::slice::from_raw_parts_mut(inner.ptr, len) }
})
}
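/// Takes the backing `Vec<T>` out of the storage (leaving it empty) if the
/// storage is exclusively owned, Vec-backed, and `T` matches the layout the
/// Vec was allocated with.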
pub fn try_take_vec(&mut self) -> Option<Vec<T>> {
if !self.is_exclusive() {
return None;
}
let ret;
unsafe {
let inner = &mut *self.inner.as_ptr();
let BackingStorage::Vec {
original_capacity,
vtable,
} = &mut inner.backing
else {
return None;
};
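// The storage may have been transmuted since the Vec was created; only hand
// the Vec back out if `T` has the original element size and alignment.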
if vtable.size != size_of::<T>() || vtable.align != align_of::<T>() {
return None;
}
let len = inner.length_in_bytes / size_of::<T>();
ret = Vec::from_raw_parts(inner.ptr, len, *original_capacity);
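// Neutralize the inner so its Drop neither runs destructors nor frees the
// buffer we just handed out.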
*original_capacity = 0;
inner.length_in_bytes = 0;
}
Some(ret)
}
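/// Temporarily exposes the storage as a mutable `Vec<T>`; the Vec is written
/// back into the storage when the returned guard is dropped.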
pub fn try_as_mut_vec(&mut self) -> Option<SharedStorageAsVecMut<'_, T>> {
Some(SharedStorageAsVecMut {
vec: ManuallyDrop::new(self.try_take_vec()?),
ss: self,
})
}
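/// Consumes the storage and returns the backing Vec if possible; otherwise
/// returns the storage unchanged.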
pub fn try_into_vec(mut self) -> Result<Vec<T>, Self> {
self.try_take_vec().ok_or(self)
}
#[inline(always)]
fn inner(&self) -> &SharedStorageInner<T> {
unsafe { &*self.inner.as_ptr() }
}
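/// Cold path for the final drop: frees the inner box, whose own `Drop` impl
/// releases the backing buffer.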
#[cold]
unsafe fn drop_slow(&mut self) {
unsafe { drop(Box::from_raw(self.inner.as_ptr())) }
}
}
impl<T: Pod> SharedStorage<T> {
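/// Reinterprets the storage as a buffer of `U` without copying. Returns `self`
/// unchanged if the byte length or pointer alignment is incompatible with `U`.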
pub fn try_transmute<U: Pod>(self) -> Result<SharedStorage<U>, Self> {
let inner = self.inner();
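// The byte length must be a multiple of `size_of::<U>()`; this holds
// automatically whenever `size_of::<U>()` divides `size_of::<T>()`.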
if !size_of::<T>().is_multiple_of(size_of::<U>())
&& !inner.length_in_bytes.is_multiple_of(size_of::<U>())
{
return Err(self);
}
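// The pointer must be aligned for `U`; this holds automatically whenever
// `align_of::<U>()` divides `align_of::<T>()`.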
if !align_of::<T>().is_multiple_of(align_of::<U>()) && !inner.ptr.cast::<U>().is_aligned() {
return Err(self);
}
let storage = SharedStorage {
inner: self.inner.cast(),
phantom: PhantomData,
};
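// Reuse the same inner allocation under the new element type and forget `self`
// so the reference count is not decremented.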
std::mem::forget(self);
Ok(storage)
}
}
impl SharedStorage<u8> {
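/// Converts a `Vec` of any `Pod` type into byte storage; `u8` has size and
/// alignment 1, so the transmute can never fail.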
pub fn bytes_from_pod_vec<T: Pod>(v: Vec<T>) -> Self {
SharedStorage::from_vec(v)
.try_transmute::<u8>()
.unwrap_or_else(|_| unreachable!())
}
}
impl<T> Deref for SharedStorage<T> {
type Target = [T];
#[inline]
fn deref(&self) -> &Self::Target {
unsafe {
let inner = self.inner();
let len = inner.length_in_bytes / size_of::<T>();
core::slice::from_raw_parts(inner.ptr, len)
}
}
}
impl<T> Clone for SharedStorage<T> {
fn clone(&self) -> Self {
let inner = self.inner();
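// Leaked storage is never deallocated, so its refcount does not need to be
// maintained. As in `Arc`, a relaxed increment is sufficient.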
if !matches!(inner.backing, BackingStorage::Leaked) {
inner.ref_count.fetch_add(1, Ordering::Relaxed);
}
Self {
inner: self.inner,
phantom: PhantomData,
}
}
}
impl<T> Drop for SharedStorage<T> {
fn drop(&mut self) {
let inner = self.inner();
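// Leaked storage (including the shared empty static) is never freed.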
if matches!(inner.backing, BackingStorage::Leaked) {
return;
}
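// Arc-style ordering: the release decrement pairs with an acquire fence before
// deallocation, so writes made through other handles are visible here.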
if inner.ref_count.fetch_sub(1, Ordering::Release) == 1 {
std::sync::atomic::fence(Ordering::Acquire);
unsafe {
self.drop_slow();
}
}
}
}