// Path: blob/main/crates/polars-ops/src/chunked_array/repeat_by.rs
// 6939 views
use arrow::array::builder::{ArrayBuilder, ShareStrategy, make_builder};1use arrow::array::{Array, IntoBoxedArray, ListArray, NullArray};2use arrow::bitmap::BitmapBuilder;3use arrow::offset::Offsets;4use arrow::pushable::Pushable;5use polars_core::prelude::*;6use polars_core::with_match_physical_numeric_polars_type;78type LargeListArray = ListArray<i64>;910fn check_lengths(length_srs: usize, length_by: usize) -> PolarsResult<()> {11polars_ensure!(12(length_srs == length_by) | (length_by == 1) | (length_srs == 1),13ShapeMismatch: "repeat_by argument and the Series should have equal length, or at least one of them should have length 1. Series length {}, by length {}",14length_srs, length_by15);16Ok(())17}1819fn new_by(by: &IdxCa, len: usize) -> IdxCa {20if let Some(x) = by.get(0) {21let values = std::iter::repeat_n(x, len).collect::<Vec<IdxSize>>();22IdxCa::new(PlSmallStr::EMPTY, values)23} else {24IdxCa::full_null(PlSmallStr::EMPTY, len)25}26}2728fn repeat_by_primitive<T>(ca: &ChunkedArray<T>, by: &IdxCa) -> PolarsResult<ListChunked>29where30T: PolarsNumericType,31{32check_lengths(ca.len(), by.len())?;3334match (ca.len(), by.len()) {35(left_len, right_len) if left_len == right_len => {36Ok(arity::binary(ca, by, |arr, by| {37let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| {38opt_by.map(|by| std::iter::repeat_n(opt_v.copied(), *by as usize))39});4041// SAFETY: length of iter is trusted.42unsafe {43LargeListArray::from_iter_primitive_trusted_len(44iter,45T::get_static_dtype().to_arrow(CompatLevel::newest()),46)47}48}))49},50(_, 1) => {51let by = new_by(by, ca.len());52repeat_by_primitive(ca, &by)53},54(1, _) => {55let new_array = ca.new_from_index(0, by.len());56repeat_by_primitive(&new_array, by)57},58// we have already checked the length59_ => unreachable!(),60}61}6263fn repeat_by_bool(ca: &BooleanChunked, by: &IdxCa) -> PolarsResult<ListChunked> {64check_lengths(ca.len(), by.len())?;6566match (ca.len(), by.len()) {67(left_len, right_len) if left_len == 
right_len => {68Ok(arity::binary(ca, by, |arr, by| {69let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| {70opt_by.map(|by| std::iter::repeat_n(opt_v, *by as usize))71});7273// SAFETY: length of iter is trusted.74unsafe { LargeListArray::from_iter_bool_trusted_len(iter) }75}))76},77(_, 1) => {78let by = new_by(by, ca.len());79repeat_by_bool(ca, &by)80},81(1, _) => {82let new_array = ca.new_from_index(0, by.len());83repeat_by_bool(&new_array, by)84},85// we have already checked the length86_ => unreachable!(),87}88}8990fn repeat_by_binary(ca: &BinaryChunked, by: &IdxCa) -> PolarsResult<ListChunked> {91check_lengths(ca.len(), by.len())?;9293match (ca.len(), by.len()) {94(left_len, right_len) if left_len == right_len => {95Ok(arity::binary(ca, by, |arr, by| {96let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| {97opt_by.map(|by| std::iter::repeat_n(opt_v, *by as usize))98});99100// SAFETY: length of iter is trusted.101unsafe { LargeListArray::from_iter_binary_trusted_len(iter, ca.len()) }102}))103},104(_, 1) => {105let by = new_by(by, ca.len());106repeat_by_binary(ca, &by)107},108(1, _) => {109let new_array = ca.new_from_index(0, by.len());110repeat_by_binary(&new_array, by)111},112// we have already checked the length113_ => unreachable!(),114}115}116117fn repeat_by_list(ca: &ListChunked, by: &IdxCa) -> PolarsResult<ListChunked> {118check_lengths(ca.len(), by.len())?;119120match (ca.len(), by.len()) {121(left_len, right_len) if left_len == right_len => Ok(repeat_by_generic_inner(ca, by)),122(_, 1) => {123let by = new_by(by, ca.len());124repeat_by_list(ca, &by)125},126(1, _) => {127let new_array = ca.new_from_index(0, by.len());128repeat_by_list(&new_array, by)129},130// we have already checked the length131_ => unreachable!(),132}133}134135fn repeat_by_null(ca: &NullChunked, by: &IdxCa) -> PolarsResult<ListChunked> {136check_lengths(ca.len(), by.len())?;137138match (ca.len(), by.len()) {139(left_len, right_len) if left_len == right_len => {140let 
arr_length = by.iter().flatten().map(|x| x as usize).sum();141let arr = NullArray::new(ArrowDataType::Null, arr_length);142143let mut validity = BitmapBuilder::with_capacity(by.len());144let mut offsets = Offsets::<i64>::with_capacity(by.len());145for n_repeat in by.iter() {146validity.push(n_repeat.is_some());147if let Some(repeats) = n_repeat {148offsets.push(repeats as usize);149} else {150offsets.push_null();151}152}153154let array = LargeListArray::new(155ListArray::<i64>::default_datatype(arr.dtype().clone()),156offsets.into(),157arr.into_boxed(),158validity.into_opt_validity(),159);160161Ok(unsafe {162ListChunked::from_chunks_and_dtype(163ca.name().clone(),164vec![array.into_boxed()],165DataType::List(Box::new(DataType::Null)),166)167})168},169(_, 1) => {170let by = new_by(by, ca.len());171repeat_by_null(ca, &by)172},173(1, _) => {174let new_array = ca.new_from_index(0, by.len());175let new_array = new_array.null().unwrap();176repeat_by_null(new_array, by)177},178// we have already checked the length179_ => unreachable!(),180}181}182183#[cfg(feature = "dtype-array")]184fn repeat_by_array(ca: &ArrayChunked, by: &IdxCa) -> PolarsResult<ListChunked> {185check_lengths(ca.len(), by.len())?;186187match (ca.len(), by.len()) {188(left_len, right_len) if left_len == right_len => Ok(repeat_by_generic_inner(ca, by)),189(_, 1) => {190let by = new_by(by, ca.len());191repeat_by_array(ca, &by)192},193(1, _) => {194let new_array = ca.new_from_index(0, by.len());195repeat_by_array(&new_array, by)196},197// we have already checked the length198_ => unreachable!(),199}200}201202#[cfg(feature = "dtype-struct")]203fn repeat_by_struct(ca: &StructChunked, by: &IdxCa) -> PolarsResult<ListChunked> {204check_lengths(ca.len(), by.len())?;205206match (ca.len(), by.len()) {207(left_len, right_len) if left_len == right_len => Ok(repeat_by_generic_inner(ca, by)),208(_, 1) => {209let by = new_by(by, ca.len());210repeat_by_struct(ca, &by)211},212(1, _) => {213let new_array = 
ca.new_from_index(0, by.len());214repeat_by_struct(&new_array, by)215},216// we have already checked the length217_ => unreachable!(),218}219}220221fn repeat_by_generic_inner<T: PolarsDataType>(ca: &ChunkedArray<T>, by: &IdxCa) -> ListChunked {222let mut builder = make_builder(&ca.dtype().to_arrow(CompatLevel::newest()));223arity::binary(ca, by, |arr, by| {224let arr_length = by.iter().flatten().map(|x| *x as usize).sum();225builder.reserve(arr_length);226227let mut validity = BitmapBuilder::with_capacity(by.len());228let mut offsets = Offsets::<i64>::with_capacity(by.len());229for (idx, n_repeat) in by.iter().enumerate() {230validity.push(n_repeat.is_some());231if let Some(repeats) = n_repeat {232offsets.push(*repeats as usize);233builder.subslice_extend_repeated(234arr,235idx,2361,237*repeats as usize,238ShareStrategy::Always,239);240} else {241offsets.push_null();242}243}244245let repeated_values = builder.freeze_reset();246LargeListArray::new(247ListArray::<i64>::default_datatype(arr.dtype().clone()),248offsets.into(),249repeated_values,250validity.into_opt_validity(),251)252})253}254255pub fn repeat_by(s: &Series, by: &IdxCa) -> PolarsResult<ListChunked> {256let s_phys = s.to_physical_repr();257use DataType as D;258let out = match s_phys.dtype() {259D::Null => repeat_by_null(s_phys.null().unwrap(), by),260D::Boolean => repeat_by_bool(s_phys.bool().unwrap(), by),261D::String => {262let ca = s_phys.str().unwrap();263repeat_by_binary(&ca.as_binary(), by)264.and_then(|ca| ca.apply_to_inner(&|s| unsafe { s.cast_unchecked(&D::String) }))265},266D::Binary => repeat_by_binary(s_phys.binary().unwrap(), by),267dt if dt.is_primitive_numeric() => {268with_match_physical_numeric_polars_type!(dt, |$T| {269let ca: &ChunkedArray<$T> = s_phys.as_ref().as_ref().as_ref();270repeat_by_primitive(ca, by)271})272},273D::List(_) => repeat_by_list(s_phys.list().unwrap(), by),274#[cfg(feature = "dtype-struct")]275D::Struct(_) => repeat_by_struct(s_phys.struct_().unwrap(), 
by),276#[cfg(feature = "dtype-array")]277D::Array(_, _) => repeat_by_array(s_phys.array().unwrap(), by),278_ => polars_bail!(opq = repeat_by, s.dtype()),279};280out.and_then(|ca| {281let logical_type = s.dtype();282ca.apply_to_inner(&|s| unsafe { s.from_physical_unchecked(logical_type) })283})284}285286287