Path: blob/main/crates/polars-core/src/series/ops/reshape.rs
8440 views
use std::borrow::Cow;12use arrow::array::*;3use arrow::bitmap::Bitmap;4use arrow::offset::{Offsets, OffsetsBuffer};5use polars_compute::gather::sublist::list::array_to_unit_list;6use polars_error::{PolarsResult, polars_bail, polars_ensure};7use polars_utils::format_tuple;89use crate::chunked_array::builder::get_list_builder;10use crate::datatypes::{DataType, ListChunked};11use crate::prelude::{IntoSeries, Series, *};1213fn reshape_fast_path(name: PlSmallStr, s: &Series) -> Series {14let mut ca = ListChunked::from_chunk_iter(15name,16s.chunks().iter().map(|arr| array_to_unit_list(arr.clone())),17);1819ca.set_inner_dtype(s.dtype().clone());20ca.set_fast_explode();21ca.into_series()22}2324impl Series {25/// Recurse nested types until we are at the leaf array.26pub fn get_leaf_array(&self) -> Series {27let s = self;28match s.dtype() {29#[cfg(feature = "dtype-array")]30DataType::Array(dtype, _) => {31let ca = s.array().unwrap();32let chunks = ca33.downcast_iter()34.map(|arr| arr.values().clone())35.collect::<Vec<_>>();36// Safety: guarded by the type system37unsafe { Series::from_chunks_and_dtype_unchecked(s.name().clone(), chunks, dtype) }38.get_leaf_array()39},40DataType::List(dtype) => {41let ca = s.list().unwrap();42let chunks = ca43.downcast_iter()44.map(|arr| arr.values().clone())45.collect::<Vec<_>>();46// Safety: guarded by the type system47unsafe { Series::from_chunks_and_dtype_unchecked(s.name().clone(), chunks, dtype) }48.get_leaf_array()49},50_ => s.clone(),51}52}5354/// TODO: Move this somewhere else?55pub fn list_offsets_and_validities_recursive(56&self,57) -> (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>) {58let mut offsets = vec![];59let mut validities = vec![];6061let mut s = self.rechunk();6263while let DataType::List(_) = s.dtype() {64let ca = s.list().unwrap();65offsets.push(ca.offsets().unwrap());66validities.push(ca.rechunk_validity());67s = ca.get_inner();68}6970(offsets, validities)71}7273/// Convert the values of this Series to a ListChunked with a length of 1,74/// so a Series of `[1, 2, 3]` becomes `[[1, 2, 3]]`.75pub fn implode(&self) -> PolarsResult<ListChunked> {76let s = self;77let s = s.rechunk();78let values = s.array_ref(0);7980let offsets = vec![0i64, values.len() as i64];81let inner_type = s.dtype();8283let dtype = ListArray::<i64>::default_datatype(values.dtype().clone());8485// SAFETY: offsets are correct.86let arr = unsafe {87ListArray::new(88dtype,89Offsets::new_unchecked(offsets).into(),90values.clone(),91None,92)93};9495let mut ca = ListChunked::with_chunk(s.name().clone(), arr);96unsafe { ca.to_logical(inner_type.clone()) };97ca.set_fast_explode();98Ok(ca)99}100101#[cfg(feature = "dtype-array")]102pub fn reshape_array(&self, dimensions: &[ReshapeDimension]) -> PolarsResult<Series> {103polars_ensure!(104!dimensions.is_empty(),105InvalidOperation: "at least one dimension must be specified"106);107108let leaf_array = self109.trim_lists_to_normalized_offsets()110.as_ref()111.unwrap_or(self)112.get_leaf_array()113.rechunk();114let size = leaf_array.len();115116let mut total_dim_size = 1;117let mut num_infers = 0;118for &dim in dimensions {119match dim {120ReshapeDimension::Infer => num_infers += 1,121ReshapeDimension::Specified(dim) => total_dim_size *= dim.get() as usize,122}123}124125polars_ensure!(num_infers <= 1, InvalidOperation: "can only specify one inferred dimension");126127if size == 0 {128polars_ensure!(129num_infers > 0 || total_dim_size == 0,130InvalidOperation: "cannot reshape empty array into shape without zero dimension: {}",131format_tuple!(dimensions),132);133134let mut prev_arrow_dtype = leaf_array135.dtype()136.to_physical()137.to_arrow(CompatLevel::newest());138let mut prev_dtype = leaf_array.dtype().clone();139let mut prev_array = leaf_array.chunks()[0].clone();140141// @NOTE: We need to collect the iterator here because it is lazily processed.142let mut current_length = dimensions[0].get_or_infer(0);143let len_iter = dimensions[1..]144.iter()145.map(|d| {146let length = current_length as usize;147current_length *= d.get_or_infer(0);148length149})150.collect::<Vec<_>>();151152// We pop the outer dimension as that is the height of the series.153for (dim, length) in dimensions[1..].iter().zip(len_iter).rev() {154// Infer dimension if needed155let dim = dim.get_or_infer(0);156prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true);157prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize);158159prev_array =160FixedSizeListArray::new(prev_arrow_dtype.clone(), length, prev_array, None)161.boxed();162}163164return Ok(unsafe {165Series::from_chunks_and_dtype_unchecked(166leaf_array.name().clone(),167vec![prev_array],168&prev_dtype,169)170});171}172173polars_ensure!(174total_dim_size > 0,175InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}",176format_tuple!(dimensions)177);178179polars_ensure!(180size.is_multiple_of(total_dim_size),181InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dimensions)182);183184let leaf_array = leaf_array.rechunk();185let mut prev_arrow_dtype = leaf_array186.dtype()187.to_physical()188.to_arrow(CompatLevel::newest());189let mut prev_dtype = leaf_array.dtype().clone();190let mut prev_array = leaf_array.chunks()[0].clone();191let inferred_size = (size / total_dim_size) as u64;192let outer_dimension = dimensions[0].get_or_infer(inferred_size);193194// We pop the outer dimension as that is the height of the series.195for dim in dimensions[1..].iter().rev() {196// Infer dimension if needed197let dim = dim.get_or_infer(inferred_size);198prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true);199prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize);200201prev_array = FixedSizeListArray::new(202prev_arrow_dtype.clone(),203prev_array.len() / dim as usize,204prev_array,205None,206)207.boxed();208}209210polars_ensure!(211prev_array.len() as u64 == outer_dimension,212InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dimensions)213);214215Ok(unsafe {216Series::from_chunks_and_dtype_unchecked(217leaf_array.name().clone(),218vec![prev_array],219&prev_dtype,220)221})222}223224pub fn reshape_list(&self, dimensions: &[ReshapeDimension]) -> PolarsResult<Series> {225polars_ensure!(226!dimensions.is_empty(),227InvalidOperation: "at least one dimension must be specified"228);229230let s = self;231let s = if let DataType::List(_) = s.dtype() {232Cow::Owned(s.explode(ExplodeOptions {233empty_as_null: false,234keep_nulls: true,235})?)236} else {237Cow::Borrowed(s)238};239240let s_ref = s.as_ref();241242// let dimensions = dimensions.to_vec();243244match dimensions.len() {2451 => {246polars_ensure!(247dimensions[0].get().is_none_or( |dim| dim as usize == s_ref.len()),248InvalidOperation: "cannot reshape len {} into shape {:?}", s_ref.len(), dimensions,249);250Ok(s_ref.clone())251},2522 => {253let rows = dimensions[0];254let cols = dimensions[1];255256if s_ref.is_empty() {257if rows.get_or_infer(0) == 0 && cols.get_or_infer(0) <= 1 {258let s = reshape_fast_path(s.name().clone(), s_ref);259return Ok(s);260} else {261polars_bail!(InvalidOperation: "cannot reshape len 0 into shape {}", format_tuple!(dimensions))262}263}264265use ReshapeDimension as RD;266// Infer dimension.267268let (rows, cols) = match (rows, cols) {269(RD::Infer, RD::Specified(cols)) if cols.get() >= 1 => {270(s_ref.len() as u64 / cols.get(), cols.get())271},272(RD::Specified(rows), RD::Infer) if rows.get() >= 1 => {273(rows.get(), s_ref.len() as u64 / rows.get())274},275(RD::Infer, RD::Infer) => (s_ref.len() as u64, 1u64),276(RD::Specified(rows), RD::Specified(cols)) => (rows.get(), cols.get()),277_ => polars_bail!(InvalidOperation: "reshape of non-zero list into zero list"),278};279280// Fast path, we can create a unit list so we only allocate offsets.281if rows as usize == s_ref.len() && cols == 1 {282let s = reshape_fast_path(s.name().clone(), s_ref);283return Ok(s);284}285286polars_ensure!(287(rows*cols) as usize == s_ref.len() && rows >= 1 && cols >= 1,288InvalidOperation: "cannot reshape len {} into shape {:?}", s_ref.len(), dimensions,289);290291let mut builder =292get_list_builder(s_ref.dtype(), s_ref.len(), rows as usize, s.name().clone());293294let mut offset = 0u64;295for _ in 0..rows {296let row = s_ref.slice(offset as i64, cols as usize);297builder.append_series(&row).unwrap();298offset += cols;299}300Ok(builder.finish().into_series())301},302_ => {303polars_bail!(InvalidOperation: "more than two dimensions not supported in reshaping to List.\n\nConsider reshaping to Array type.");304},305}306}307}308309#[cfg(test)]310mod test {311use super::*;312use crate::prelude::*;313314#[test]315fn test_to_list() -> PolarsResult<()> {316let s = Series::new("a".into(), &[1, 2, 3]);317318let mut builder = get_list_builder(s.dtype(), s.len(), 1, s.name().clone());319builder.append_series(&s).unwrap();320let expected = builder.finish();321322let out = s.implode()?;323assert!(expected.into_series().equals(&out.into_series()));324325Ok(())326}327328#[test]329fn test_reshape() -> PolarsResult<()> {330let s = Series::new("a".into(), &[1, 2, 3, 4]);331332for (dims, list_len) in [333(&[-1, 1], 4),334(&[4, 1], 4),335(&[2, 2], 2),336(&[-1, 2], 2),337(&[2, -1], 2),338] {339let dims = dims340.iter()341.map(|&v| ReshapeDimension::new(v))342.collect::<Vec<_>>();343let out = s.reshape_list(&dims)?;344assert_eq!(out.len(), list_len);345assert!(matches!(out.dtype(), DataType::List(_)));346assert_eq!(347out.explode(ExplodeOptions {348empty_as_null: true,349keep_nulls: true,350})?351.len(),3524353);354}355356Ok(())357}358}359360361