Path: blob/main/crates/polars-compute/src/cast/primitive_to.rs
8446 views
use std::hash::Hash;12use arrow::array::*;3use arrow::bitmap::{Bitmap, BitmapBuilder};4use arrow::compute::arity::unary;5use arrow::datatypes::{ArrowDataType, TimeUnit};6use arrow::offset::{Offset, Offsets};7use arrow::types::NativeType;8use num_traits::AsPrimitive;9#[cfg(feature = "dtype-decimal")]10use num_traits::{Float, ToPrimitive};11use polars_error::PolarsResult;12use polars_utils::float16::pf16;13use polars_utils::pl_str::PlSmallStr;14use polars_utils::vec::PushUnchecked;1516use super::CastOptionsImpl;17use super::temporal::*;18#[cfg(feature = "dtype-decimal")]19use crate::decimal::{dec128_verify_prec_scale, f64_to_dec128, i128_to_dec128};2021pub trait SerPrimitive {22fn write(f: &mut Vec<u8>, val: Self) -> usize23where24Self: Sized;25}2627macro_rules! impl_ser_primitive {28($ptype:ident) => {29impl SerPrimitive for $ptype {30fn write(f: &mut Vec<u8>, val: Self) -> usize31where32Self: Sized,33{34let mut buffer = itoa::Buffer::new();35let value = buffer.format(val);36f.extend_from_slice(value.as_bytes());37value.len()38}39}40};41}4243impl_ser_primitive!(i8);44impl_ser_primitive!(i16);45impl_ser_primitive!(i32);46impl_ser_primitive!(i64);47impl_ser_primitive!(i128);48impl_ser_primitive!(u8);49impl_ser_primitive!(u16);50impl_ser_primitive!(u32);51impl_ser_primitive!(u64);52impl_ser_primitive!(u128);5354impl SerPrimitive for pf16 {55fn write(f: &mut Vec<u8>, val: Self) -> usize56where57Self: Sized,58{59f32::write(f, AsPrimitive::<f32>::as_(val))60}61}6263impl SerPrimitive for f32 {64fn write(f: &mut Vec<u8>, val: Self) -> usize65where66Self: Sized,67{68let mut buffer = zmij::Buffer::new();69let value = buffer.format(val);70f.extend_from_slice(value.as_bytes());71value.len()72}73}7475impl SerPrimitive for f64 {76fn write(f: &mut Vec<u8>, val: Self) -> usize77where78Self: Sized,79{80let mut buffer = zmij::Buffer::new();81let value = buffer.format(val);82f.extend_from_slice(value.as_bytes());83value.len()84}85}8687fn fallible_unary<I, F, G, O>(88array: &PrimitiveArray<I>,89op: F,90fail: G,91dtype: ArrowDataType,92) -> PrimitiveArray<O>93where94I: NativeType,95O: NativeType,96F: Fn(I) -> O,97G: Fn(I) -> bool,98{99let values = array.values();100let mut out = Vec::with_capacity(array.len());101let mut i = 0;102103while i < array.len() && !fail(values[i]) {104// SAFETY: We allocated enough before.105unsafe { out.push_unchecked(op(values[i])) };106i += 1;107}108109if out.len() == array.len() {110return PrimitiveArray::<O>::new(dtype, out.into(), array.validity().cloned());111}112113let mut validity = BitmapBuilder::with_capacity(array.len());114validity.extend_constant(out.len(), true);115116for &value in &values[out.len()..] {117// SAFETY: We allocated enough before.118unsafe {119out.push_unchecked(op(value));120validity.push_unchecked(!fail(value));121}122}123124debug_assert_eq!(out.len(), array.len());125debug_assert_eq!(validity.len(), array.len());126127let validity = validity.freeze();128let validity = match array.validity() {129None => validity,130Some(arr_validity) => arrow::bitmap::and(&validity, arr_validity),131};132133PrimitiveArray::<O>::new(dtype, out.into(), Some(validity))134}135136fn primitive_to_values_and_offsets<T: NativeType + SerPrimitive, O: Offset>(137from: &PrimitiveArray<T>,138) -> (Vec<u8>, Offsets<O>) {139let mut values: Vec<u8> = Vec::with_capacity(from.len());140let mut offsets: Vec<O> = Vec::with_capacity(from.len() + 1);141offsets.push(O::default());142143let mut offset: usize = 0;144145unsafe {146for &x in from.values().iter() {147let len = T::write(&mut values, x);148149offset += len;150offsets.push(O::from_as_usize(offset));151}152values.set_len(offset);153values.shrink_to_fit();154// SAFETY: offsets _are_ monotonically increasing155let offsets = Offsets::new_unchecked(offsets);156157(values, offsets)158}159}160161/// Returns a [`BooleanArray`] where every element is different from zero.162/// Validity is preserved.163pub fn primitive_to_boolean<T: NativeType>(164from: &PrimitiveArray<T>,165to_type: ArrowDataType,166) -> BooleanArray {167let iter = from.values().iter().map(|v| *v != T::default());168let values = Bitmap::from_trusted_len_iter(iter);169170BooleanArray::new(to_type, values, from.validity().cloned())171}172173pub(super) fn primitive_to_boolean_dyn<T>(174from: &dyn Array,175to_type: ArrowDataType,176) -> PolarsResult<Box<dyn Array>>177where178T: NativeType,179{180let from = from.as_any().downcast_ref().unwrap();181Ok(Box::new(primitive_to_boolean::<T>(from, to_type)))182}183184/// Returns a [`Utf8Array`] where every element is the utf8 representation of the number.185pub(super) fn primitive_to_utf8<T: NativeType + SerPrimitive, O: Offset>(186from: &PrimitiveArray<T>,187) -> Utf8Array<O> {188let (values, offsets) = primitive_to_values_and_offsets(from);189unsafe {190Utf8Array::<O>::new_unchecked(191Utf8Array::<O>::default_dtype(),192offsets.into(),193values.into(),194from.validity().cloned(),195)196}197}198199pub(super) fn primitive_to_utf8_dyn<T, O>(from: &dyn Array) -> PolarsResult<Box<dyn Array>>200where201O: Offset,202T: NativeType + SerPrimitive,203{204let from = from.as_any().downcast_ref().unwrap();205Ok(Box::new(primitive_to_utf8::<T, O>(from)))206}207208pub(super) fn primitive_to_primitive_dyn<I, O>(209from: &dyn Array,210to_type: &ArrowDataType,211options: CastOptionsImpl,212) -> PolarsResult<Box<dyn Array>>213where214I: NativeType + num_traits::NumCast + num_traits::AsPrimitive<O>,215O: NativeType + num_traits::NumCast,216{217let from = from.as_any().downcast_ref::<PrimitiveArray<I>>().unwrap();218if options.wrapped {219Ok(Box::new(primitive_as_primitive::<I, O>(from, to_type)))220} else {221Ok(Box::new(primitive_to_primitive::<I, O>(from, to_type)))222}223}224225/// Cast [`PrimitiveArray`] to a [`PrimitiveArray`] of another physical type via numeric conversion.226pub fn primitive_to_primitive<I, O>(227from: &PrimitiveArray<I>,228to_type: &ArrowDataType,229) -> PrimitiveArray<O>230where231I: NativeType + num_traits::NumCast,232O: NativeType + num_traits::NumCast,233{234let iter = from235.iter()236.map(|v| v.and_then(|x| num_traits::cast::cast::<I, O>(*x)));237PrimitiveArray::<O>::from_trusted_len_iter(iter).to(to_type.clone())238}239240/// Returns a [`PrimitiveArray<i128>`] with the cast values. Values are `None` on overflow241#[cfg(feature = "dtype-decimal")]242pub fn integer_to_decimal<T: NativeType + ToPrimitive>(243from: &PrimitiveArray<T>,244to_precision: usize,245to_scale: usize,246) -> PrimitiveArray<i128> {247assert!(dec128_verify_prec_scale(to_precision, to_scale).is_ok());248let values = from249.iter()250.map(|x| i128_to_dec128(x?.to_i128()?, to_precision, to_scale));251PrimitiveArray::<i128>::from_trusted_len_iter(values)252.to(ArrowDataType::Decimal(to_precision, to_scale))253}254255#[cfg(feature = "dtype-decimal")]256pub(super) fn integer_to_decimal_dyn<T>(257from: &dyn Array,258precision: usize,259scale: usize,260) -> PolarsResult<Box<dyn Array>>261where262T: NativeType + ToPrimitive,263{264let from = from.as_any().downcast_ref().unwrap();265Ok(Box::new(integer_to_decimal::<T>(from, precision, scale)))266}267268/// Returns a [`PrimitiveArray<i128>`] with the cast values. Values are `None` on overflow269#[cfg(feature = "dtype-decimal")]270pub fn float_to_decimal<T: NativeType + Float + AsPrimitive<f64>>(271from: &PrimitiveArray<T>,272to_precision: usize,273to_scale: usize,274) -> PrimitiveArray<i128> {275assert!(dec128_verify_prec_scale(to_precision, to_scale).is_ok());276let values = from277.iter()278.map(|x| f64_to_dec128(x?.as_(), to_precision, to_scale));279PrimitiveArray::<i128>::from_trusted_len_iter(values)280.to(ArrowDataType::Decimal(to_precision, to_scale))281}282283#[cfg(feature = "dtype-decimal")]284pub(super) fn float_to_decimal_dyn<T: NativeType + Float + AsPrimitive<f64>>(285from: &dyn Array,286precision: usize,287scale: usize,288) -> PolarsResult<Box<dyn Array>> {289let from = from.as_any().downcast_ref().unwrap();290Ok(Box::new(float_to_decimal::<T>(from, precision, scale)))291}292293/// Cast [`PrimitiveArray`] as a [`PrimitiveArray`]294/// Same as `number as to_number_type` in rust295pub fn primitive_as_primitive<I, O>(296from: &PrimitiveArray<I>,297to_type: &ArrowDataType,298) -> PrimitiveArray<O>299where300I: NativeType + num_traits::AsPrimitive<O>,301O: NativeType,302{303unary(from, num_traits::AsPrimitive::<O>::as_, to_type.clone())304}305306/// Cast [`PrimitiveArray`] to a [`PrimitiveArray`] of the same physical type.307/// This is O(1).308pub fn primitive_to_same_primitive<T>(309from: &PrimitiveArray<T>,310to_type: &ArrowDataType,311) -> PrimitiveArray<T>312where313T: NativeType,314{315PrimitiveArray::<T>::new(316to_type.clone(),317from.values().clone(),318from.validity().cloned(),319)320}321322/// Cast [`PrimitiveArray`] to a [`PrimitiveArray`] of the same physical type.323/// This is O(1).324pub(super) fn primitive_to_same_primitive_dyn<T>(325from: &dyn Array,326to_type: &ArrowDataType,327) -> PolarsResult<Box<dyn Array>>328where329T: NativeType,330{331let from = from.as_any().downcast_ref().unwrap();332Ok(Box::new(primitive_to_same_primitive::<T>(from, to_type)))333}334335pub(super) fn primitive_to_dictionary_dyn<T: NativeType + Eq + Hash, K: DictionaryKey>(336from: &dyn Array,337) -> PolarsResult<Box<dyn Array>> {338let from = from.as_any().downcast_ref().unwrap();339primitive_to_dictionary::<T, K>(from).map(|x| Box::new(x) as Box<dyn Array>)340}341342/// Cast [`PrimitiveArray`] to [`DictionaryArray`]. Also known as packing.343/// # Errors344/// This function errors if the maximum key is smaller than the number of distinct elements345/// in the array.346pub fn primitive_to_dictionary<T: NativeType + Eq + Hash, K: DictionaryKey>(347from: &PrimitiveArray<T>,348) -> PolarsResult<DictionaryArray<K>> {349let iter = from.iter().map(|x| x.copied());350let mut array = MutableDictionaryArray::<K, _>::try_empty(MutablePrimitiveArray::<T>::from(351from.dtype().clone(),352))?;353array.reserve(from.len());354array.try_extend(iter)?;355356Ok(array.into())357}358359/// # Safety360///361/// `dtype` should be valid for primitive.362pub unsafe fn primitive_map_is_valid<T: NativeType>(363from: &PrimitiveArray<T>,364f: impl Fn(T) -> bool,365dtype: ArrowDataType,366) -> PrimitiveArray<T> {367let values = from.values().clone();368369let validity: Bitmap = values.iter().map(|&v| f(v)).collect();370371let validity = if validity.unset_bits() > 0 {372let new_validity = match from.validity() {373None => validity,374Some(v) => v & &validity,375};376377Some(new_validity)378} else {379from.validity().cloned()380};381382// SAFETY:383// - Validity did not change length384// - dtype should be valid385unsafe { PrimitiveArray::new_unchecked(dtype, values, validity) }386}387388/// Conversion of `Int32` to `Time32(TimeUnit::Second)`389pub fn int32_to_time32s(from: &PrimitiveArray<i32>) -> PrimitiveArray<i32> {390// SAFETY: Time32(TimeUnit::Second) is valid for Int32391unsafe {392primitive_map_is_valid(393from,394|v| (0..SECONDS_IN_DAY as i32).contains(&v),395ArrowDataType::Time32(TimeUnit::Second),396)397}398}399400/// Conversion of `Int32` to `Time32(TimeUnit::Millisecond)`401pub fn int32_to_time32ms(from: &PrimitiveArray<i32>) -> PrimitiveArray<i32> {402// SAFETY: Time32(TimeUnit::Millisecond) is valid for Int32403unsafe {404primitive_map_is_valid(405from,406|v| (0..MILLISECONDS_IN_DAY as i32).contains(&v),407ArrowDataType::Time32(TimeUnit::Millisecond),408)409}410}411412/// Conversion of `Int64` to `Time32(TimeUnit::Microsecond)`413pub fn int64_to_time64us(from: &PrimitiveArray<i64>) -> PrimitiveArray<i64> {414// SAFETY: Time64(TimeUnit::Microsecond) is valid for Int64415unsafe {416primitive_map_is_valid(417from,418|v| (0..MICROSECONDS_IN_DAY).contains(&v),419ArrowDataType::Time32(TimeUnit::Microsecond),420)421}422}423424/// Conversion of `Int64` to `Time32(TimeUnit::Nanosecond)`425pub fn int64_to_time64ns(from: &PrimitiveArray<i64>) -> PrimitiveArray<i64> {426// SAFETY: Time64(TimeUnit::Nanosecond) is valid for Int64427unsafe {428primitive_map_is_valid(429from,430|v| (0..NANOSECONDS_IN_DAY).contains(&v),431ArrowDataType::Time64(TimeUnit::Nanosecond),432)433}434}435436/// Conversion of dates437pub fn date32_to_date64(from: &PrimitiveArray<i32>) -> PrimitiveArray<i64> {438unary(439from,440|x| x as i64 * MILLISECONDS_IN_DAY,441ArrowDataType::Date64,442)443}444445/// Conversion of dates446pub fn date64_to_date32(from: &PrimitiveArray<i64>) -> PrimitiveArray<i32> {447unary(448from,449|x| (x / MILLISECONDS_IN_DAY) as i32,450ArrowDataType::Date32,451)452}453454/// Conversion of times455pub fn time32s_to_time32ms(from: &PrimitiveArray<i32>) -> PrimitiveArray<i32> {456fallible_unary(457from,458|x| x.wrapping_mul(1000),459|x| x.checked_mul(1000).is_none(),460ArrowDataType::Time32(TimeUnit::Millisecond),461)462}463464/// Conversion of times465pub fn time32ms_to_time32s(from: &PrimitiveArray<i32>) -> PrimitiveArray<i32> {466unary(from, |x| x / 1000, ArrowDataType::Time32(TimeUnit::Second))467}468469/// Conversion of times470pub fn time64us_to_time64ns(from: &PrimitiveArray<i64>) -> PrimitiveArray<i64> {471fallible_unary(472from,473|x| x.wrapping_mul(1000),474|x| x.checked_mul(1000).is_none(),475ArrowDataType::Time64(TimeUnit::Nanosecond),476)477}478479/// Conversion of times480pub fn time64ns_to_time64us(from: &PrimitiveArray<i64>) -> PrimitiveArray<i64> {481unary(482from,483|x| x / 1000,484ArrowDataType::Time64(TimeUnit::Microsecond),485)486}487488/// Conversion of timestamp489pub fn timestamp_to_date64(from: &PrimitiveArray<i64>, from_unit: TimeUnit) -> PrimitiveArray<i64> {490let from_size = time_unit_multiple(from_unit);491let to_size = MILLISECONDS;492let to_type = ArrowDataType::Date64;493494// Scale time_array by (to_size / from_size) using a495// single integer operation, but need to avoid integer496// math rounding down to zero497498match to_size.cmp(&from_size) {499std::cmp::Ordering::Less => unary(from, |x| x / (from_size / to_size), to_type),500std::cmp::Ordering::Equal => primitive_to_same_primitive(from, &to_type),501std::cmp::Ordering::Greater => fallible_unary(502from,503|x| x.wrapping_mul(to_size / from_size),504|x| x.checked_mul(to_size / from_size).is_none(),505to_type,506),507}508}509510/// Conversion of timestamp511pub fn timestamp_to_date32(from: &PrimitiveArray<i64>, from_unit: TimeUnit) -> PrimitiveArray<i32> {512let from_size = time_unit_multiple(from_unit) * SECONDS_IN_DAY;513unary(from, |x| (x / from_size) as i32, ArrowDataType::Date32)514}515516/// Conversion of time517pub fn time32_to_time64(518from: &PrimitiveArray<i32>,519from_unit: TimeUnit,520to_unit: TimeUnit,521) -> PrimitiveArray<i64> {522let from_size = time_unit_multiple(from_unit);523let to_size = time_unit_multiple(to_unit);524let divisor = to_size / from_size;525fallible_unary(526from,527|x| (x as i64).wrapping_mul(divisor),528|x| (x as i64).checked_mul(divisor).is_none(),529ArrowDataType::Time64(to_unit),530)531}532533/// Conversion of time534pub fn time64_to_time32(535from: &PrimitiveArray<i64>,536from_unit: TimeUnit,537to_unit: TimeUnit,538) -> PrimitiveArray<i32> {539let from_size = time_unit_multiple(from_unit);540let to_size = time_unit_multiple(to_unit);541let divisor = from_size / to_size;542unary(543from,544|x| (x / divisor) as i32,545ArrowDataType::Time32(to_unit),546)547}548549/// Conversion of timestamp550pub fn timestamp_to_timestamp(551from: &PrimitiveArray<i64>,552from_unit: TimeUnit,553to_unit: TimeUnit,554tz: &Option<PlSmallStr>,555) -> PrimitiveArray<i64> {556let from_size = time_unit_multiple(from_unit);557let to_size = time_unit_multiple(to_unit);558let to_type = ArrowDataType::Timestamp(to_unit, tz.clone());559// we either divide or multiply, depending on size of each unit560if from_size >= to_size {561unary(from, |x| x / (from_size / to_size), to_type)562} else {563fallible_unary(564from,565|x| x.wrapping_mul(to_size / from_size),566|x| x.checked_mul(to_size / from_size).is_none(),567to_type,568)569}570}571572/// Returns a [`Utf8Array`] where every element is the utf8 representation of the number.573pub(super) fn primitive_to_binview<T: NativeType + SerPrimitive>(574from: &PrimitiveArray<T>,575) -> BinaryViewArray {576let mut mutable = MutableBinaryViewArray::with_capacity(from.len());577578let mut scratch = vec![];579for &x in from.values().iter() {580unsafe { scratch.set_len(0) };581T::write(&mut scratch, x);582mutable.push_value_ignore_validity(&scratch)583}584585mutable.freeze().with_validity(from.validity().cloned())586}587588pub(super) fn primitive_to_binview_dyn<T>(from: &dyn Array) -> BinaryViewArray589where590T: NativeType + SerPrimitive,591{592let from = from.as_any().downcast_ref().unwrap();593primitive_to_binview::<T>(from)594}595596597