Path: blob/main/crates/polars-core/src/series/ops/reshape.rs
6940 views
use std::borrow::Cow;12use arrow::array::*;3use arrow::bitmap::Bitmap;4use arrow::offset::{Offsets, OffsetsBuffer};5use polars_compute::gather::sublist::list::array_to_unit_list;6use polars_error::{PolarsResult, polars_bail, polars_ensure};7use polars_utils::format_tuple;89use crate::chunked_array::builder::get_list_builder;10use crate::datatypes::{DataType, ListChunked};11use crate::prelude::{IntoSeries, Series, *};1213fn reshape_fast_path(name: PlSmallStr, s: &Series) -> Series {14let mut ca = ListChunked::from_chunk_iter(15name,16s.chunks().iter().map(|arr| array_to_unit_list(arr.clone())),17);1819ca.set_inner_dtype(s.dtype().clone());20ca.set_fast_explode();21ca.into_series()22}2324impl Series {25/// Recurse nested types until we are at the leaf array.26pub fn get_leaf_array(&self) -> Series {27let s = self;28match s.dtype() {29#[cfg(feature = "dtype-array")]30DataType::Array(dtype, _) => {31let ca = s.array().unwrap();32let chunks = ca33.downcast_iter()34.map(|arr| arr.values().clone())35.collect::<Vec<_>>();36// Safety: guarded by the type system37unsafe { Series::from_chunks_and_dtype_unchecked(s.name().clone(), chunks, dtype) }38.get_leaf_array()39},40DataType::List(dtype) => {41let ca = s.list().unwrap();42let chunks = ca43.downcast_iter()44.map(|arr| arr.values().clone())45.collect::<Vec<_>>();46// Safety: guarded by the type system47unsafe { Series::from_chunks_and_dtype_unchecked(s.name().clone(), chunks, dtype) }48.get_leaf_array()49},50_ => s.clone(),51}52}5354/// TODO: Move this somewhere else?55pub fn list_offsets_and_validities_recursive(56&self,57) -> (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>) {58let mut offsets = vec![];59let mut validities = vec![];6061let mut s = self.rechunk();6263while let DataType::List(_) = s.dtype() {64let ca = s.list().unwrap();65offsets.push(ca.offsets().unwrap());66validities.push(ca.rechunk_validity());67s = ca.get_inner();68}6970(offsets, validities)71}7273/// Convert the values of this Series to a ListChunked with a length of 1,74/// so a Series of `[1, 2, 3]` becomes `[[1, 2, 3]]`.75pub fn implode(&self) -> PolarsResult<ListChunked> {76let s = self;77let s = s.rechunk();78let values = s.array_ref(0);7980let offsets = vec![0i64, values.len() as i64];81let inner_type = s.dtype();8283let dtype = ListArray::<i64>::default_datatype(values.dtype().clone());8485// SAFETY: offsets are correct.86let arr = unsafe {87ListArray::new(88dtype,89Offsets::new_unchecked(offsets).into(),90values.clone(),91None,92)93};9495let mut ca = ListChunked::with_chunk(s.name().clone(), arr);96unsafe { ca.to_logical(inner_type.clone()) };97ca.set_fast_explode();98Ok(ca)99}100101#[cfg(feature = "dtype-array")]102pub fn reshape_array(&self, dimensions: &[ReshapeDimension]) -> PolarsResult<Series> {103polars_ensure!(104!dimensions.is_empty(),105InvalidOperation: "at least one dimension must be specified"106);107108let leaf_array = self.get_leaf_array().rechunk();109let size = leaf_array.len();110111let mut total_dim_size = 1;112let mut num_infers = 0;113for &dim in dimensions {114match dim {115ReshapeDimension::Infer => num_infers += 1,116ReshapeDimension::Specified(dim) => total_dim_size *= dim.get() as usize,117}118}119120polars_ensure!(num_infers <= 1, InvalidOperation: "can only specify one inferred dimension");121122if size == 0 {123polars_ensure!(124num_infers > 0 || total_dim_size == 0,125InvalidOperation: "cannot reshape empty array into shape without zero dimension: {}",126format_tuple!(dimensions),127);128129let mut prev_arrow_dtype = leaf_array130.dtype()131.to_physical()132.to_arrow(CompatLevel::newest());133let mut prev_dtype = leaf_array.dtype().clone();134let mut prev_array = leaf_array.chunks()[0].clone();135136// @NOTE: We need to collect the iterator here because it is lazily processed.137let mut current_length = dimensions[0].get_or_infer(0);138let len_iter = dimensions[1..]139.iter()140.map(|d| {141let length = current_length as usize;142current_length *= d.get_or_infer(0);143length144})145.collect::<Vec<_>>();146147// We pop the outer dimension as that is the height of the series.148for (dim, length) in dimensions[1..].iter().zip(len_iter).rev() {149// Infer dimension if needed150let dim = dim.get_or_infer(0);151prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true);152prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize);153154prev_array =155FixedSizeListArray::new(prev_arrow_dtype.clone(), length, prev_array, None)156.boxed();157}158159return Ok(unsafe {160Series::from_chunks_and_dtype_unchecked(161leaf_array.name().clone(),162vec![prev_array],163&prev_dtype,164)165});166}167168polars_ensure!(169total_dim_size > 0,170InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}",171format_tuple!(dimensions)172);173174polars_ensure!(175size.is_multiple_of(total_dim_size),176InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dimensions)177);178179let leaf_array = leaf_array.rechunk();180let mut prev_arrow_dtype = leaf_array181.dtype()182.to_physical()183.to_arrow(CompatLevel::newest());184let mut prev_dtype = leaf_array.dtype().clone();185let mut prev_array = leaf_array.chunks()[0].clone();186187// We pop the outer dimension as that is the height of the series.188for dim in dimensions[1..].iter().rev() {189// Infer dimension if needed190let dim = dim.get_or_infer((size / total_dim_size) as u64);191prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true);192prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize);193194prev_array = FixedSizeListArray::new(195prev_arrow_dtype.clone(),196prev_array.len() / dim as usize,197prev_array,198None,199)200.boxed();201}202Ok(unsafe {203Series::from_chunks_and_dtype_unchecked(204leaf_array.name().clone(),205vec![prev_array],206&prev_dtype,207)208})209}210211pub fn reshape_list(&self, dimensions: &[ReshapeDimension]) -> PolarsResult<Series> {212polars_ensure!(213!dimensions.is_empty(),214InvalidOperation: "at least one dimension must be specified"215);216217let s = self;218let s = if let DataType::List(_) = s.dtype() {219Cow::Owned(s.explode(true)?)220} else {221Cow::Borrowed(s)222};223224let s_ref = s.as_ref();225226// let dimensions = dimensions.to_vec();227228match dimensions.len() {2291 => {230polars_ensure!(231dimensions[0].get().is_none_or( |dim| dim as usize == s_ref.len()),232InvalidOperation: "cannot reshape len {} into shape {:?}", s_ref.len(), dimensions,233);234Ok(s_ref.clone())235},2362 => {237let rows = dimensions[0];238let cols = dimensions[1];239240if s_ref.is_empty() {241if rows.get_or_infer(0) == 0 && cols.get_or_infer(0) <= 1 {242let s = reshape_fast_path(s.name().clone(), s_ref);243return Ok(s);244} else {245polars_bail!(InvalidOperation: "cannot reshape len 0 into shape {}", format_tuple!(dimensions))246}247}248249use ReshapeDimension as RD;250// Infer dimension.251252let (rows, cols) = match (rows, cols) {253(RD::Infer, RD::Specified(cols)) if cols.get() >= 1 => {254(s_ref.len() as u64 / cols.get(), cols.get())255},256(RD::Specified(rows), RD::Infer) if rows.get() >= 1 => {257(rows.get(), s_ref.len() as u64 / rows.get())258},259(RD::Infer, RD::Infer) => (s_ref.len() as u64, 1u64),260(RD::Specified(rows), RD::Specified(cols)) => (rows.get(), cols.get()),261_ => polars_bail!(InvalidOperation: "reshape of non-zero list into zero list"),262};263264// Fast path, we can create a unit list so we only allocate offsets.265if rows as usize == s_ref.len() && cols == 1 {266let s = reshape_fast_path(s.name().clone(), s_ref);267return Ok(s);268}269270polars_ensure!(271(rows*cols) as usize == s_ref.len() && rows >= 1 && cols >= 1,272InvalidOperation: "cannot reshape len {} into shape {:?}", s_ref.len(), dimensions,273);274275let mut builder =276get_list_builder(s_ref.dtype(), s_ref.len(), rows as usize, s.name().clone());277278let mut offset = 0u64;279for _ in 0..rows {280let row = s_ref.slice(offset as i64, cols as usize);281builder.append_series(&row).unwrap();282offset += cols;283}284Ok(builder.finish().into_series())285},286_ => {287polars_bail!(InvalidOperation: "more than two dimensions not supported in reshaping to List.\n\nConsider reshaping to Array type.");288},289}290}291}292293#[cfg(test)]294mod test {295use super::*;296use crate::prelude::*;297298#[test]299fn test_to_list() -> PolarsResult<()> {300let s = Series::new("a".into(), &[1, 2, 3]);301302let mut builder = get_list_builder(s.dtype(), s.len(), 1, s.name().clone());303builder.append_series(&s).unwrap();304let expected = builder.finish();305306let out = s.implode()?;307assert!(expected.into_series().equals(&out.into_series()));308309Ok(())310}311312#[test]313fn test_reshape() -> PolarsResult<()> {314let s = Series::new("a".into(), &[1, 2, 3, 4]);315316for (dims, list_len) in [317(&[-1, 1], 4),318(&[4, 1], 4),319(&[2, 2], 2),320(&[-1, 2], 2),321(&[2, -1], 2),322] {323let dims = dims324.iter()325.map(|&v| ReshapeDimension::new(v))326.collect::<Vec<_>>();327let out = s.reshape_list(&dims)?;328assert_eq!(out.len(), list_len);329assert!(matches!(out.dtype(), DataType::List(_)));330assert_eq!(out.explode(false)?.len(), 4);331}332333Ok(())334}335}336337338