Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-core/src/series/ops/reshape.rs
8440 views
1
use std::borrow::Cow;
2
3
use arrow::array::*;
4
use arrow::bitmap::Bitmap;
5
use arrow::offset::{Offsets, OffsetsBuffer};
6
use polars_compute::gather::sublist::list::array_to_unit_list;
7
use polars_error::{PolarsResult, polars_bail, polars_ensure};
8
use polars_utils::format_tuple;
9
10
use crate::chunked_array::builder::get_list_builder;
11
use crate::datatypes::{DataType, ListChunked};
12
use crate::prelude::{IntoSeries, Series, *};
13
14
/// Wrap every value of `s` in its own single-element list, producing a `List` series called
/// `name` whose inner dtype is the dtype of `s`.
///
/// Each chunk is converted with `array_to_unit_list`, so the values of `s` are reused and only
/// the list offsets need to be allocated. The result is flagged as fast-explode.
fn reshape_fast_path(name: PlSmallStr, s: &Series) -> Series {
    let unit_lists = s
        .chunks()
        .iter()
        .map(|chunk| array_to_unit_list(chunk.clone()));
    let mut out = ListChunked::from_chunk_iter(name, unit_lists);

    // Restore the (possibly logical) dtype of the original values.
    out.set_inner_dtype(s.dtype().clone());
    out.set_fast_explode();
    out.into_series()
}
24
25
impl Series {
26
/// Recurse nested types until we are at the leaf array.
27
pub fn get_leaf_array(&self) -> Series {
28
let s = self;
29
match s.dtype() {
30
#[cfg(feature = "dtype-array")]
31
DataType::Array(dtype, _) => {
32
let ca = s.array().unwrap();
33
let chunks = ca
34
.downcast_iter()
35
.map(|arr| arr.values().clone())
36
.collect::<Vec<_>>();
37
// Safety: guarded by the type system
38
unsafe { Series::from_chunks_and_dtype_unchecked(s.name().clone(), chunks, dtype) }
39
.get_leaf_array()
40
},
41
DataType::List(dtype) => {
42
let ca = s.list().unwrap();
43
let chunks = ca
44
.downcast_iter()
45
.map(|arr| arr.values().clone())
46
.collect::<Vec<_>>();
47
// Safety: guarded by the type system
48
unsafe { Series::from_chunks_and_dtype_unchecked(s.name().clone(), chunks, dtype) }
49
.get_leaf_array()
50
},
51
_ => s.clone(),
52
}
53
}
54
55
/// TODO: Move this somewhere else?
56
pub fn list_offsets_and_validities_recursive(
57
&self,
58
) -> (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>) {
59
let mut offsets = vec![];
60
let mut validities = vec![];
61
62
let mut s = self.rechunk();
63
64
while let DataType::List(_) = s.dtype() {
65
let ca = s.list().unwrap();
66
offsets.push(ca.offsets().unwrap());
67
validities.push(ca.rechunk_validity());
68
s = ca.get_inner();
69
}
70
71
(offsets, validities)
72
}
73
74
/// Convert the values of this Series to a ListChunked with a length of 1,
75
/// so a Series of `[1, 2, 3]` becomes `[[1, 2, 3]]`.
76
pub fn implode(&self) -> PolarsResult<ListChunked> {
77
let s = self;
78
let s = s.rechunk();
79
let values = s.array_ref(0);
80
81
let offsets = vec![0i64, values.len() as i64];
82
let inner_type = s.dtype();
83
84
let dtype = ListArray::<i64>::default_datatype(values.dtype().clone());
85
86
// SAFETY: offsets are correct.
87
let arr = unsafe {
88
ListArray::new(
89
dtype,
90
Offsets::new_unchecked(offsets).into(),
91
values.clone(),
92
None,
93
)
94
};
95
96
let mut ca = ListChunked::with_chunk(s.name().clone(), arr);
97
unsafe { ca.to_logical(inner_type.clone()) };
98
ca.set_fast_explode();
99
Ok(ca)
100
}
101
102
#[cfg(feature = "dtype-array")]
103
pub fn reshape_array(&self, dimensions: &[ReshapeDimension]) -> PolarsResult<Series> {
104
polars_ensure!(
105
!dimensions.is_empty(),
106
InvalidOperation: "at least one dimension must be specified"
107
);
108
109
let leaf_array = self
110
.trim_lists_to_normalized_offsets()
111
.as_ref()
112
.unwrap_or(self)
113
.get_leaf_array()
114
.rechunk();
115
let size = leaf_array.len();
116
117
let mut total_dim_size = 1;
118
let mut num_infers = 0;
119
for &dim in dimensions {
120
match dim {
121
ReshapeDimension::Infer => num_infers += 1,
122
ReshapeDimension::Specified(dim) => total_dim_size *= dim.get() as usize,
123
}
124
}
125
126
polars_ensure!(num_infers <= 1, InvalidOperation: "can only specify one inferred dimension");
127
128
if size == 0 {
129
polars_ensure!(
130
num_infers > 0 || total_dim_size == 0,
131
InvalidOperation: "cannot reshape empty array into shape without zero dimension: {}",
132
format_tuple!(dimensions),
133
);
134
135
let mut prev_arrow_dtype = leaf_array
136
.dtype()
137
.to_physical()
138
.to_arrow(CompatLevel::newest());
139
let mut prev_dtype = leaf_array.dtype().clone();
140
let mut prev_array = leaf_array.chunks()[0].clone();
141
142
// @NOTE: We need to collect the iterator here because it is lazily processed.
143
let mut current_length = dimensions[0].get_or_infer(0);
144
let len_iter = dimensions[1..]
145
.iter()
146
.map(|d| {
147
let length = current_length as usize;
148
current_length *= d.get_or_infer(0);
149
length
150
})
151
.collect::<Vec<_>>();
152
153
// We pop the outer dimension as that is the height of the series.
154
for (dim, length) in dimensions[1..].iter().zip(len_iter).rev() {
155
// Infer dimension if needed
156
let dim = dim.get_or_infer(0);
157
prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true);
158
prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize);
159
160
prev_array =
161
FixedSizeListArray::new(prev_arrow_dtype.clone(), length, prev_array, None)
162
.boxed();
163
}
164
165
return Ok(unsafe {
166
Series::from_chunks_and_dtype_unchecked(
167
leaf_array.name().clone(),
168
vec![prev_array],
169
&prev_dtype,
170
)
171
});
172
}
173
174
polars_ensure!(
175
total_dim_size > 0,
176
InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}",
177
format_tuple!(dimensions)
178
);
179
180
polars_ensure!(
181
size.is_multiple_of(total_dim_size),
182
InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dimensions)
183
);
184
185
let leaf_array = leaf_array.rechunk();
186
let mut prev_arrow_dtype = leaf_array
187
.dtype()
188
.to_physical()
189
.to_arrow(CompatLevel::newest());
190
let mut prev_dtype = leaf_array.dtype().clone();
191
let mut prev_array = leaf_array.chunks()[0].clone();
192
let inferred_size = (size / total_dim_size) as u64;
193
let outer_dimension = dimensions[0].get_or_infer(inferred_size);
194
195
// We pop the outer dimension as that is the height of the series.
196
for dim in dimensions[1..].iter().rev() {
197
// Infer dimension if needed
198
let dim = dim.get_or_infer(inferred_size);
199
prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true);
200
prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize);
201
202
prev_array = FixedSizeListArray::new(
203
prev_arrow_dtype.clone(),
204
prev_array.len() / dim as usize,
205
prev_array,
206
None,
207
)
208
.boxed();
209
}
210
211
polars_ensure!(
212
prev_array.len() as u64 == outer_dimension,
213
InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dimensions)
214
);
215
216
Ok(unsafe {
217
Series::from_chunks_and_dtype_unchecked(
218
leaf_array.name().clone(),
219
vec![prev_array],
220
&prev_dtype,
221
)
222
})
223
}
224
225
pub fn reshape_list(&self, dimensions: &[ReshapeDimension]) -> PolarsResult<Series> {
226
polars_ensure!(
227
!dimensions.is_empty(),
228
InvalidOperation: "at least one dimension must be specified"
229
);
230
231
let s = self;
232
let s = if let DataType::List(_) = s.dtype() {
233
Cow::Owned(s.explode(ExplodeOptions {
234
empty_as_null: false,
235
keep_nulls: true,
236
})?)
237
} else {
238
Cow::Borrowed(s)
239
};
240
241
let s_ref = s.as_ref();
242
243
// let dimensions = dimensions.to_vec();
244
245
match dimensions.len() {
246
1 => {
247
polars_ensure!(
248
dimensions[0].get().is_none_or( |dim| dim as usize == s_ref.len()),
249
InvalidOperation: "cannot reshape len {} into shape {:?}", s_ref.len(), dimensions,
250
);
251
Ok(s_ref.clone())
252
},
253
2 => {
254
let rows = dimensions[0];
255
let cols = dimensions[1];
256
257
if s_ref.is_empty() {
258
if rows.get_or_infer(0) == 0 && cols.get_or_infer(0) <= 1 {
259
let s = reshape_fast_path(s.name().clone(), s_ref);
260
return Ok(s);
261
} else {
262
polars_bail!(InvalidOperation: "cannot reshape len 0 into shape {}", format_tuple!(dimensions))
263
}
264
}
265
266
use ReshapeDimension as RD;
267
// Infer dimension.
268
269
let (rows, cols) = match (rows, cols) {
270
(RD::Infer, RD::Specified(cols)) if cols.get() >= 1 => {
271
(s_ref.len() as u64 / cols.get(), cols.get())
272
},
273
(RD::Specified(rows), RD::Infer) if rows.get() >= 1 => {
274
(rows.get(), s_ref.len() as u64 / rows.get())
275
},
276
(RD::Infer, RD::Infer) => (s_ref.len() as u64, 1u64),
277
(RD::Specified(rows), RD::Specified(cols)) => (rows.get(), cols.get()),
278
_ => polars_bail!(InvalidOperation: "reshape of non-zero list into zero list"),
279
};
280
281
// Fast path, we can create a unit list so we only allocate offsets.
282
if rows as usize == s_ref.len() && cols == 1 {
283
let s = reshape_fast_path(s.name().clone(), s_ref);
284
return Ok(s);
285
}
286
287
polars_ensure!(
288
(rows*cols) as usize == s_ref.len() && rows >= 1 && cols >= 1,
289
InvalidOperation: "cannot reshape len {} into shape {:?}", s_ref.len(), dimensions,
290
);
291
292
let mut builder =
293
get_list_builder(s_ref.dtype(), s_ref.len(), rows as usize, s.name().clone());
294
295
let mut offset = 0u64;
296
for _ in 0..rows {
297
let row = s_ref.slice(offset as i64, cols as usize);
298
builder.append_series(&row).unwrap();
299
offset += cols;
300
}
301
Ok(builder.finish().into_series())
302
},
303
_ => {
304
polars_bail!(InvalidOperation: "more than two dimensions not supported in reshaping to List.\n\nConsider reshaping to Array type.");
305
},
306
}
307
}
308
}
309
310
#[cfg(test)]
mod test {
    use super::*;
    use crate::prelude::*;

    #[test]
    fn test_to_list() -> PolarsResult<()> {
        let s = Series::new("a".into(), &[1, 2, 3]);

        // Reference result: a single list row holding every value of `s`.
        let expected = {
            let mut builder = get_list_builder(s.dtype(), s.len(), 1, s.name().clone());
            builder.append_series(&s).unwrap();
            builder.finish().into_series()
        };

        let imploded = s.implode()?.into_series();
        assert!(expected.equals(&imploded));

        Ok(())
    }

    #[test]
    fn test_reshape() -> PolarsResult<()> {
        let s = Series::new("a".into(), &[1, 2, 3, 4]);

        // (requested shape, expected number of list rows); a negative entry is inferred.
        let cases = [
            ([-1, 1], 4),
            ([4, 1], 4),
            ([2, 2], 2),
            ([-1, 2], 2),
            ([2, -1], 2),
        ];

        for (shape, expected_rows) in cases {
            let dims: Vec<ReshapeDimension> = shape
                .into_iter()
                .map(|v| ReshapeDimension::new(v))
                .collect();

            let out = s.reshape_list(&dims)?;
            assert_eq!(out.len(), expected_rows);
            assert!(matches!(out.dtype(), DataType::List(_)));

            // Exploding the reshaped lists must give back all four original values.
            let exploded = out.explode(ExplodeOptions {
                empty_as_null: true,
                keep_nulls: true,
            })?;
            assert_eq!(exploded.len(), 4);
        }

        Ok(())
    }
}
360
361