// Source: GitHub repository pola-rs/polars
// Path: blob/main/crates/polars-compute/src/gather/sublist/list.rs
1
use arrow::array::{Array, ArrayRef, ListArray};
2
use arrow::legacy::prelude::*;
3
use arrow::legacy::trusted_len::TrustedLenPush;
4
use arrow::legacy::utils::CustomIterTools;
5
use arrow::offset::{Offsets, OffsetsBuffer};
6
use polars_utils::IdxSize;
7
8
use crate::gather::take_unchecked;
9
10
/// Get the indices that would result in a get operation on the lists values.
11
/// for example, consider this list:
12
/// ```text
13
/// [[1, 2, 3],
14
/// [4, 5],
15
/// [6]]
16
///
17
/// This contains the following values array:
18
/// [1, 2, 3, 4, 5, 6]
19
///
20
/// get index 0
21
/// would lead to the following indexes:
22
/// [0, 3, 5].
23
/// if we use those in a take operation on the values array we get:
24
/// [1, 4, 6]
25
///
26
///
27
/// get index -1
28
/// would lead to the following indexes:
29
/// [2, 4, 5].
30
/// if we use those in a take operation on the values array we get:
31
/// [3, 5, 6]
32
///
33
/// ```
34
fn sublist_get_indexes(arr: &ListArray<i64>, index: i64) -> IdxArr {
35
let offsets = arr.offsets().as_slice();
36
let mut iter = offsets.iter();
37
38
// the indices can be sliced, so we should not start at 0.
39
let mut cum_offset = (*offsets.first().unwrap_or(&0)) as IdxSize;
40
41
if let Some(mut previous) = iter.next().copied() {
42
if arr.null_count() == 0 {
43
iter.map(|&offset| {
44
let len = offset - previous;
45
previous = offset;
46
// make sure that empty lists don't get accessed
47
// and out of bounds return null
48
if len == 0 {
49
return None;
50
}
51
if index >= len {
52
cum_offset += len as IdxSize;
53
return None;
54
}
55
56
let out = index
57
.negative_to_usize(len as usize)
58
.map(|idx| idx as IdxSize + cum_offset);
59
cum_offset += len as IdxSize;
60
out
61
})
62
.collect_trusted()
63
} else {
64
// we can ensure that validity is not none as we have null value.
65
let validity = arr.validity().unwrap();
66
iter.enumerate()
67
.map(|(i, &offset)| {
68
let len = offset - previous;
69
previous = offset;
70
// make sure that empty and null lists don't get accessed and return null.
71
// SAFETY, we are within bounds
72
if len == 0 || !unsafe { validity.get_bit_unchecked(i) } {
73
cum_offset += len as IdxSize;
74
return None;
75
}
76
77
// make sure that out of bounds return null
78
if index >= len {
79
cum_offset += len as IdxSize;
80
return None;
81
}
82
83
let out = index
84
.negative_to_usize(len as usize)
85
.map(|idx| idx as IdxSize + cum_offset);
86
cum_offset += len as IdxSize;
87
out
88
})
89
.collect_trusted()
90
}
91
} else {
92
IdxArr::from_slice([])
93
}
94
}
95
96
pub fn sublist_get(arr: &ListArray<i64>, index: i64) -> ArrayRef {
97
let take_by = sublist_get_indexes(arr, index);
98
let values = arr.values();
99
// SAFETY:
100
// the indices we generate are in bounds
101
unsafe { take_unchecked(&**values, &take_by) }
102
}
103
104
/// Check if an index is out of bounds for at least one sublist.
105
pub fn index_is_oob(arr: &ListArray<i64>, index: i64) -> bool {
106
if arr.null_count() == 0 {
107
arr.offsets()
108
.lengths()
109
.any(|len| index.negative_to_usize(len).is_none())
110
} else {
111
arr.offsets()
112
.lengths()
113
.zip(arr.validity().unwrap())
114
.any(|(len, valid)| {
115
if valid {
116
index.negative_to_usize(len).is_none()
117
} else {
118
// skip nulls
119
false
120
}
121
})
122
}
123
}
124
125
/// Convert a list `[1, 2, 3]` to a list type of `[[1], [2], [3]]`
126
pub fn array_to_unit_list(array: ArrayRef) -> ListArray<i64> {
127
let len = array.len();
128
let mut offsets = Vec::with_capacity(len + 1);
129
// SAFETY: we allocated enough
130
unsafe {
131
offsets.push_unchecked(0i64);
132
133
for _ in 0..len {
134
offsets.push_unchecked(offsets.len() as i64)
135
}
136
};
137
138
// SAFETY:
139
// offsets are monotonically increasing
140
unsafe {
141
let offsets: OffsetsBuffer<i64> = Offsets::new_unchecked(offsets).into();
142
let dtype = ListArray::<i64>::default_datatype(array.dtype().clone());
143
ListArray::<i64>::new(dtype, offsets, array, None)
144
}
145
}
146
147
#[cfg(test)]
mod test {
    use arrow::array::{Int32Array, PrimitiveArray};
    use arrow::datatypes::ArrowDataType;

    use super::*;

    /// Builds the fixture `[[1, 2, 3], [4, 5], [6]]` as an `Int32` list array without a
    /// validity bitmap.
    fn get_array() -> ListArray<i64> {
        let offsets = OffsetsBuffer::try_from(vec![0i64, 3, 5, 6]).unwrap();
        let values = Int32Array::from_slice([1, 2, 3, 4, 5, 6]);

        ListArray::<i64>::new(
            ListArray::<i64>::default_datatype(ArrowDataType::Int32),
            offsets,
            Box::new(values),
            None,
        )
    }

    #[test]
    fn test_sublist_get_indexes() {
        let fixture = get_array();

        // First element of every sublist.
        let first = sublist_get_indexes(&fixture, 0);
        assert_eq!(first.values().as_slice(), &[0, 3, 5]);

        // Last element of every sublist.
        let last = sublist_get_indexes(&fixture, -1);
        assert_eq!(last.values().as_slice(), &[2, 4, 5]);

        // Index 3 is past the end of every sublist, so all entries are null.
        let past_end = sublist_get_indexes(&fixture, 3);
        assert_eq!(past_end.null_count(), 3);

        // Sublists of mixed lengths: three of length 1, two of length 3, one of length 2.
        let values = Int32Array::from_iter([
            Some(1),
            Some(1),
            Some(3),
            Some(4),
            Some(5),
            Some(6),
            Some(7),
            Some(8),
            Some(9),
            None,
            Some(11),
        ]);
        let offsets = OffsetsBuffer::try_from(vec![0i64, 1, 2, 3, 6, 9, 11]).unwrap();
        let dtype = ListArray::<i64>::default_datatype(ArrowDataType::Int32);
        let mixed = ListArray::<i64>::new(dtype, offsets, Box::new(values), None);

        let second = sublist_get_indexes(&mixed, 1);
        assert_eq!(
            second.into_iter().collect::<Vec<_>>(),
            &[None, None, None, Some(4), Some(7), Some(10)]
        );
    }

    #[test]
    fn test_sublist_get() {
        let fixture = get_array();

        let first = sublist_get(&fixture, 0);
        let first = first
            .as_any()
            .downcast_ref::<PrimitiveArray<i32>>()
            .unwrap();
        assert_eq!(first.values().as_slice(), &[1, 4, 6]);

        let last = sublist_get(&fixture, -1);
        let last = last
            .as_any()
            .downcast_ref::<PrimitiveArray<i32>>()
            .unwrap();
        assert_eq!(last.values().as_slice(), &[3, 5, 6]);
    }
}
210
211