Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-arrow/src/array/union/mod.rs
6939 views
1
use polars_error::{PolarsResult, polars_bail, polars_err};
2
3
use super::{Array, Splitable, new_empty_array, new_null_array};
4
use crate::bitmap::Bitmap;
5
use crate::buffer::Buffer;
6
use crate::datatypes::{ArrowDataType, Field, UnionMode};
7
use crate::scalar::{Scalar, new_scalar};
8
9
mod ffi;
10
pub(super) mod fmt;
11
mod iterator;
12
13
// The parts of an `ArrowDataType::Union` as returned by `UnionArray::try_get_all`:
// the child fields, the optional type ids, and the union mode (see `UnionMode::is_sparse`).
type UnionComponents<'a> = (&'a [Field], Option<&'a [i32]>, UnionMode);
14
15
/// [`UnionArray`] represents an array whose each slot can contain different values.
///
/// How to read a value at slot `i` (this is what [`UnionArray::index`] computes):
/// ```ignore
/// let type_id = self.types()[i];
/// // when the data type declares type ids, `map` translates a type id into a child position
/// let field_index = map.map(|m| m[type_id as usize]).unwrap_or(type_id as usize);
/// let field = &self.fields()[field_index];
/// // dense unions (offsets set) look the slot up in `offsets`;
/// // sparse unions use `i` shifted by the array's own `offset`
/// let slot = self.offsets().map(|x| x[i] as usize).unwrap_or(i + offset);
/// let value = new_scalar(field.as_ref(), slot);
/// ```
#[derive(Clone)]
pub struct UnionArray {
    // Invariant: every item in `types` is `>= 0`. Without `map`, every item is also
    // `< fields.len()`.
    types: Buffer<i8>,
    // Set when the data type declares type ids. It is indexed by type id and holds the position
    // of the matching child in `fields`.
    // Invariant: every entry looked up through an item of `types` is `< fields.len()`.
    map: Option<[usize; 127]>,
    // One child array per field of the union data type.
    fields: Vec<Box<dyn Array>>,
    // Invariant: set iff the union is dense; when set, `offsets.len() == types.len()`.
    offsets: Option<Buffer<i32>>,
    dtype: ArrowDataType,
    // Number of slots sliced off the front. Only read when `offsets` is `None` (sparse mode).
    offset: usize,
}
38
39
impl UnionArray {
40
/// Returns a new [`UnionArray`].
41
/// # Errors
42
/// This function errors iff:
43
/// * `dtype`'s physical type is not [`crate::datatypes::PhysicalType::Union`].
44
/// * the fields's len is different from the `dtype`'s children's length
45
/// * The number of `fields` is larger than `i8::MAX`
46
/// * any of the values's data type is different from its corresponding children' data type
47
pub fn try_new(
48
dtype: ArrowDataType,
49
types: Buffer<i8>,
50
fields: Vec<Box<dyn Array>>,
51
offsets: Option<Buffer<i32>>,
52
) -> PolarsResult<Self> {
53
let (f, ids, mode) = Self::try_get_all(&dtype)?;
54
55
if f.len() != fields.len() {
56
polars_bail!(ComputeError: "the number of `fields` must equal the number of children fields in DataType::Union")
57
};
58
let number_of_fields: i8 = fields.len().try_into().map_err(
59
|_| polars_err!(ComputeError: "the number of `fields` cannot be larger than i8::MAX"),
60
)?;
61
62
f
63
.iter().map(|a| a.dtype())
64
.zip(fields.iter().map(|a| a.dtype()))
65
.enumerate()
66
.try_for_each(|(index, (dtype, child))| {
67
if dtype != child {
68
polars_bail!(ComputeError:
69
"the children DataTypes of a UnionArray must equal the children data types.
70
However, the field {index} has data type {dtype:?} but the value has data type {child:?}"
71
)
72
} else {
73
Ok(())
74
}
75
})?;
76
77
if let Some(offsets) = &offsets {
78
if offsets.len() != types.len() {
79
polars_bail!(ComputeError:
80
"in a UnionArray, the offsets' length must be equal to the number of types"
81
)
82
}
83
}
84
if offsets.is_none() != mode.is_sparse() {
85
polars_bail!(ComputeError:
86
"in a sparse UnionArray, the offsets must be set (and vice-versa)",
87
)
88
}
89
90
// build hash
91
let map = if let Some(&ids) = ids.as_ref() {
92
if ids.len() != fields.len() {
93
polars_bail!(ComputeError:
94
"in a union, when the ids are set, their length must be equal to the number of fields",
95
)
96
}
97
98
// example:
99
// * types = [5, 7, 5, 7, 7, 7, 5, 7, 7, 5, 5]
100
// * ids = [5, 7]
101
// => hash = [0, 0, 0, 0, 0, 0, 1, 0, ...]
102
let mut hash = [0; 127];
103
104
for (pos, &id) in ids.iter().enumerate() {
105
if !(0..=127).contains(&id) {
106
polars_bail!(ComputeError:
107
"in a union, when the ids are set, every id must belong to [0, 128[",
108
)
109
}
110
hash[id as usize] = pos;
111
}
112
113
types.iter().try_for_each(|&type_| {
114
if type_ < 0 {
115
polars_bail!(ComputeError:
116
"in a union, when the ids are set, every type must be >= 0"
117
)
118
}
119
let id = hash[type_ as usize];
120
if id >= fields.len() {
121
polars_bail!(ComputeError:
122
"in a union, when the ids are set, each id must be smaller than the number of fields."
123
)
124
} else {
125
Ok(())
126
}
127
})?;
128
129
Some(hash)
130
} else {
131
// SAFETY: every type in types is smaller than number of fields
132
let mut is_valid = true;
133
for &type_ in types.iter() {
134
if type_ < 0 || type_ >= number_of_fields {
135
is_valid = false
136
}
137
}
138
if !is_valid {
139
polars_bail!(ComputeError:
140
"every type in `types` must be larger than 0 and smaller than the number of fields.",
141
)
142
}
143
144
None
145
};
146
147
Ok(Self {
148
dtype,
149
map,
150
fields,
151
offsets,
152
types,
153
offset: 0,
154
})
155
}
156
157
/// Returns a new [`UnionArray`].
158
/// # Panics
159
/// This function panics iff:
160
/// * `dtype`'s physical type is not [`crate::datatypes::PhysicalType::Union`].
161
/// * the fields's len is different from the `dtype`'s children's length
162
/// * any of the values's data type is different from its corresponding children' data type
163
pub fn new(
164
dtype: ArrowDataType,
165
types: Buffer<i8>,
166
fields: Vec<Box<dyn Array>>,
167
offsets: Option<Buffer<i32>>,
168
) -> Self {
169
Self::try_new(dtype, types, fields, offsets).unwrap()
170
}
171
172
/// Creates a new null [`UnionArray`].
173
pub fn new_null(dtype: ArrowDataType, length: usize) -> Self {
174
if let ArrowDataType::Union(u) = &dtype {
175
let fields = u
176
.fields
177
.iter()
178
.map(|x| new_null_array(x.dtype().clone(), length))
179
.collect();
180
181
let offsets = if u.mode.is_sparse() {
182
None
183
} else {
184
Some((0..length as i32).collect::<Vec<_>>().into())
185
};
186
187
// all from the same field
188
let types = vec![0i8; length].into();
189
190
Self::new(dtype, types, fields, offsets)
191
} else {
192
panic!("Union struct must be created with the corresponding Union DataType")
193
}
194
}
195
196
/// Creates a new empty [`UnionArray`].
197
pub fn new_empty(dtype: ArrowDataType) -> Self {
198
if let ArrowDataType::Union(u) = dtype.to_logical_type() {
199
let fields = u
200
.fields
201
.iter()
202
.map(|x| new_empty_array(x.dtype().clone()))
203
.collect();
204
205
let offsets = if u.mode.is_sparse() {
206
None
207
} else {
208
Some(Buffer::default())
209
};
210
211
Self {
212
dtype,
213
map: None,
214
fields,
215
offsets,
216
types: Buffer::new(),
217
offset: 0,
218
}
219
} else {
220
panic!("Union struct must be created with the corresponding Union DataType")
221
}
222
}
223
}
224
225
impl UnionArray {
226
/// Returns a slice of this [`UnionArray`].
227
/// # Implementation
228
/// This operation is `O(F)` where `F` is the number of fields.
229
/// # Panic
230
/// This function panics iff `offset + length > self.len()`.
231
#[inline]
232
pub fn slice(&mut self, offset: usize, length: usize) {
233
assert!(
234
offset + length <= self.len(),
235
"the offset of the new array cannot exceed the existing length"
236
);
237
unsafe { self.slice_unchecked(offset, length) }
238
}
239
240
/// Returns a slice of this [`UnionArray`].
241
/// # Implementation
242
/// This operation is `O(F)` where `F` is the number of fields.
243
///
244
/// # Safety
245
/// The caller must ensure that `offset + length <= self.len()`.
246
#[inline]
247
pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) {
248
debug_assert!(offset + length <= self.len());
249
250
self.types.slice_unchecked(offset, length);
251
if let Some(offsets) = self.offsets.as_mut() {
252
offsets.slice_unchecked(offset, length)
253
}
254
self.offset += offset;
255
}
256
257
impl_sliced!();
258
impl_into_array!();
259
}
260
261
impl UnionArray {
262
/// Returns the length of this array
263
#[inline]
264
pub fn len(&self) -> usize {
265
self.types.len()
266
}
267
268
/// The optional offsets.
269
pub fn offsets(&self) -> Option<&Buffer<i32>> {
270
self.offsets.as_ref()
271
}
272
273
/// The fields.
274
pub fn fields(&self) -> &Vec<Box<dyn Array>> {
275
&self.fields
276
}
277
278
/// The types.
279
pub fn types(&self) -> &Buffer<i8> {
280
&self.types
281
}
282
283
#[inline]
284
unsafe fn field_slot_unchecked(&self, index: usize) -> usize {
285
self.offsets()
286
.as_ref()
287
.map(|x| *x.get_unchecked(index) as usize)
288
.unwrap_or(index + self.offset)
289
}
290
291
/// Returns the index and slot of the field to select from `self.fields`.
292
#[inline]
293
pub fn index(&self, index: usize) -> (usize, usize) {
294
assert!(index < self.len());
295
unsafe { self.index_unchecked(index) }
296
}
297
298
/// Returns the index and slot of the field to select from `self.fields`.
299
/// The first value is guaranteed to be `< self.fields().len()`
300
///
301
/// # Safety
302
/// This function is safe iff `index < self.len`.
303
#[inline]
304
pub unsafe fn index_unchecked(&self, index: usize) -> (usize, usize) {
305
debug_assert!(index < self.len());
306
// SAFETY: assumption of the function
307
let type_ = unsafe { *self.types.get_unchecked(index) };
308
// SAFETY: assumption of the struct
309
let type_ = self
310
.map
311
.as_ref()
312
.map(|map| unsafe { *map.get_unchecked(type_ as usize) })
313
.unwrap_or(type_ as usize);
314
// SAFETY: assumption of the function
315
let index = self.field_slot_unchecked(index);
316
(type_, index)
317
}
318
319
/// Returns the slot `index` as a [`Scalar`].
320
/// # Panics
321
/// iff `index >= self.len()`
322
pub fn value(&self, index: usize) -> Box<dyn Scalar> {
323
assert!(index < self.len());
324
unsafe { self.value_unchecked(index) }
325
}
326
327
/// Returns the slot `index` as a [`Scalar`].
328
///
329
/// # Safety
330
/// This function is safe iff `i < self.len`.
331
pub unsafe fn value_unchecked(&self, index: usize) -> Box<dyn Scalar> {
332
debug_assert!(index < self.len());
333
let (type_, index) = self.index_unchecked(index);
334
// SAFETY: assumption of the struct
335
debug_assert!(type_ < self.fields.len());
336
let field = self.fields.get_unchecked(type_).as_ref();
337
new_scalar(field, index)
338
}
339
}
340
341
impl Array for UnionArray {
342
impl_common_array!();
343
344
fn validity(&self) -> Option<&Bitmap> {
345
None
346
}
347
348
fn with_validity(&self, _: Option<Bitmap>) -> Box<dyn Array> {
349
panic!("cannot set validity of a union array")
350
}
351
}
352
353
impl UnionArray {
354
fn try_get_all(dtype: &ArrowDataType) -> PolarsResult<UnionComponents<'_>> {
355
match dtype.to_logical_type() {
356
ArrowDataType::Union(u) => Ok((&u.fields, u.ids.as_ref().map(|x| x.as_ref()), u.mode)),
357
_ => polars_bail!(ComputeError:
358
"The UnionArray requires a logical type of DataType::Union",
359
),
360
}
361
}
362
363
fn get_all(dtype: &ArrowDataType) -> (&[Field], Option<&[i32]>, UnionMode) {
364
Self::try_get_all(dtype).unwrap()
365
}
366
367
/// Returns all fields from [`ArrowDataType::Union`].
368
/// # Panic
369
/// Panics iff `dtype`'s logical type is not [`ArrowDataType::Union`].
370
pub fn get_fields(dtype: &ArrowDataType) -> &[Field] {
371
Self::get_all(dtype).0
372
}
373
374
/// Returns whether the [`ArrowDataType::Union`] is sparse or not.
375
/// # Panic
376
/// Panics iff `dtype`'s logical type is not [`ArrowDataType::Union`].
377
pub fn is_sparse(dtype: &ArrowDataType) -> bool {
378
Self::get_all(dtype).2.is_sparse()
379
}
380
}
381
382
impl Splitable for UnionArray {
383
fn check_bound(&self, offset: usize) -> bool {
384
offset <= self.len()
385
}
386
387
unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) {
388
let (lhs_types, rhs_types) = unsafe { self.types.split_at_unchecked(offset) };
389
let (lhs_offsets, rhs_offsets) = self.offsets.as_ref().map_or((None, None), |v| {
390
let (lhs, rhs) = unsafe { v.split_at_unchecked(offset) };
391
(Some(lhs), Some(rhs))
392
});
393
394
(
395
Self {
396
types: lhs_types,
397
map: self.map,
398
fields: self.fields.clone(),
399
offsets: lhs_offsets,
400
dtype: self.dtype.clone(),
401
offset: self.offset,
402
},
403
Self {
404
types: rhs_types,
405
map: self.map,
406
fields: self.fields.clone(),
407
offsets: rhs_offsets,
408
dtype: self.dtype.clone(),
409
offset: self.offset + offset,
410
},
411
)
412
}
413
}
414
415