Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/comparisons/list.rs
6939 views
1
use arrow::array::{
2
Array, BinaryArray, BinaryViewArray, BooleanArray, DictionaryArray, FixedSizeBinaryArray,
3
ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array, Utf8ViewArray,
4
};
5
use arrow::bitmap::Bitmap;
6
use arrow::legacy::utils::CustomIterTools;
7
use arrow::types::{Offset, days_ms, f16, i256, months_days_ns};
8
9
use super::TotalEqKernel;
10
11
/// Compares two list arrays row by row and yields one boolean per row.
///
/// The expansion is an expression; the per-row booleans are gathered with
/// `collect_trusted()`, so the concrete collection type is decided by the
/// call site.
///
/// Arguments:
/// * `$lhs`, `$rhs` - the list arrays to compare. The expansion asserts that
///   they have the same `len()` and the same `dtype()` and panics otherwise.
/// * `$op` - kernel invoked on the two child-value slices belonging to a row.
/// * `$true_op` - turns the value returned by `$op` into the row's boolean.
/// * `$ineq_len_rv` - row result when the two sub-lists differ in length.
/// * `$invalid_rv` - row result when either side is null at that row.
///
/// Child arrays with a `Union` or `Map` physical type (and `FixedSizeList`
/// when the `dtype-array` feature is off) hit a `todo!` and panic.
macro_rules! compare {
    (
        $lhs:expr, $rhs:expr,
        $op:path, $true_op:expr,
        $ineq_len_rv:literal, $invalid_rv:literal
    ) => {{
        let lhs = $lhs;
        let rhs = $rhs;

        // Both sides must describe the same number of rows of the same type.
        assert_eq!(lhs.len(), rhs.len());
        assert_eq!(lhs.dtype(), rhs.dtype());

        // Runs the row loop once the concrete child array type `$T` is known.
        macro_rules! call_binary {
            ($T:ty) => {{
                let left_child: &$T = $lhs.values().as_any().downcast_ref().unwrap();
                let right_child: &$T = $rhs.values().as_any().downcast_ref().unwrap();

                let left_validity = $lhs.validity();
                let right_validity = $rhs.validity();
                let left_offsets = $lhs.offsets();
                let right_offsets = $rhs.offsets();

                (0..$lhs.len())
                    .map(|row| {
                        // A missing validity mask means every row is valid.
                        let left_valid = left_validity.is_none_or(|mask| mask.get(row).unwrap());
                        let right_valid = right_validity.is_none_or(|mask| mask.get(row).unwrap());

                        if !(left_valid && right_valid) {
                            return $invalid_rv;
                        }

                        // SAFETY: `row < len`, and ListArray's invariant is
                        // offsets.len_proxy() == len.
                        let (left_start, left_end) =
                            unsafe { left_offsets.start_end_unchecked(row) };
                        let (right_start, right_end) =
                            unsafe { right_offsets.start_end_unchecked(row) };

                        let left_len = left_end - left_start;
                        let right_len = right_end - right_start;
                        if left_len != right_len {
                            return $ineq_len_rv;
                        }

                        // Narrow a copy of each child array down to this row's elements.
                        let mut left_row = left_child.clone();
                        left_row.slice(left_start, left_len);
                        let mut right_row = right_child.clone();
                        right_row.slice(right_start, right_len);

                        $true_op($op(&left_row, &right_row))
                    })
                    .collect_trusted()
            }};
        }

        use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR};
        // Pick the concrete array type from the physical type of the child values.
        match lhs.values().dtype().to_physical_type() {
            PH::Boolean => call_binary!(BooleanArray),
            PH::BinaryView => call_binary!(BinaryViewArray),
            PH::Utf8View => call_binary!(Utf8ViewArray),

            // Primitive child values.
            PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray<i8>),
            PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray<i16>),
            PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray<i32>),
            PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray<i64>),
            PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray<i128>),
            PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray<u8>),
            PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray<u16>),
            PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray<u32>),
            PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray<u64>),
            PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray<u128>),
            PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray<f16>),
            PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray<f32>),
            PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray<f64>),
            PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray<i256>),
            PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray<days_ms>),
            PH::Primitive(PR::MonthDayNano) => {
                call_binary!(PrimitiveArray<months_days_ns>)
            },

            // Fixed-size lists are only comparable with the `dtype-array` feature.
            #[cfg(feature = "dtype-array")]
            PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray),
            #[cfg(not(feature = "dtype-array"))]
            PH::FixedSizeList => todo!(
                "Comparison of FixedSizeListArray is not supported without dtype-array feature"
            ),

            // Null, binary, string and nested child values.
            PH::Null => call_binary!(NullArray),
            PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray),
            PH::Binary => call_binary!(BinaryArray<i32>),
            PH::LargeBinary => call_binary!(BinaryArray<i64>),
            PH::Utf8 => call_binary!(Utf8Array<i32>),
            PH::LargeUtf8 => call_binary!(Utf8Array<i64>),
            PH::List => call_binary!(ListArray<i32>),
            PH::LargeList => call_binary!(ListArray<i64>),
            PH::Struct => call_binary!(StructArray),
            PH::Union => todo!("Comparison of UnionArrays is not yet supported"),
            PH::Map => todo!("Comparison of MapArrays is not yet supported"),

            // Dictionary-encoded child values, one arm per key type.
            PH::Dictionary(I::Int8) => call_binary!(DictionaryArray<i8>),
            PH::Dictionary(I::Int16) => call_binary!(DictionaryArray<i16>),
            PH::Dictionary(I::Int32) => call_binary!(DictionaryArray<i32>),
            PH::Dictionary(I::Int64) => call_binary!(DictionaryArray<i64>),
            PH::Dictionary(I::Int128) => call_binary!(DictionaryArray<i128>),
            PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray<u8>),
            PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray<u16>),
            PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray<u32>),
            PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray<u64>),
        }
    }};
}
110
111
/// Compares every row of a list array against one scalar array and yields
/// one boolean per row.
///
/// The expansion is an expression; the per-row booleans are gathered with
/// `collect_trusted()`, so the concrete collection type is decided by the
/// call site.
///
/// Arguments:
/// * `$lhs` - the child values of the list array.
/// * `$rhs` - the scalar every row is compared with. The expansion asserts
///   that `$lhs` and `$rhs` share the same `dtype()` and panics otherwise.
/// * `$offsets` - offsets delimiting each row inside `$lhs`; the number of
///   rows is `$offsets.len_proxy()`.
/// * `$validity` - optional validity mask of the list array.
/// * `$op` - kernel invoked on a row's slice of `$lhs` and the scalar.
/// * `$true_op` - turns the value returned by `$op` into the row's boolean.
/// * `$ineq_len_rv` - row result when the row's length differs from the
///   scalar's length.
/// * `$invalid_rv` - row result when the row is null.
///
/// Values with a `Union` or `Map` physical type (and `FixedSizeList` when the
/// `dtype-array` feature is off) hit a `todo!` and panic.
macro_rules! compare_broadcast {
    (
        $lhs:expr, $rhs:expr,
        $offsets:expr, $validity:expr,
        $op:path, $true_op:expr,
        $ineq_len_rv:literal, $invalid_rv:literal
    ) => {{
        let lhs = $lhs;
        let rhs = $rhs;

        // The scalar must have the same type as the list's child values.
        assert_eq!(lhs.dtype(), rhs.dtype());

        // Runs the row loop once the concrete array type `$T` is known.
        macro_rules! call_binary {
            ($T:ty) => {{
                let child: &$T = $lhs.as_any().downcast_ref().unwrap();
                let needle: &$T = $rhs.as_any().downcast_ref().unwrap();

                let offsets = $offsets;
                let validity = $validity;
                let row_count = offsets.len_proxy();

                (0..row_count)
                    .map(move |row| {
                        // A missing validity mask means every row is valid.
                        let is_valid = validity.is_none_or(|mask| mask.get(row).unwrap());
                        if !is_valid {
                            return $invalid_rv;
                        }

                        // SAFETY: `row < offsets.len_proxy()` because `row` is drawn
                        // from the range `0..row_count` built just above.
                        let (start, end) = unsafe { offsets.start_end_unchecked(row) };

                        let width = end - start;
                        if width != needle.len() {
                            return $ineq_len_rv;
                        }

                        // TODO: look for a way to view one row without the
                        // clone-then-slice round trip.
                        let mut window: $T = child.clone();
                        <$T>::slice(&mut window, start, width);

                        $true_op($op(&window, needle))
                    })
                    .collect_trusted()
            }};
        }

        use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR};
        // Pick the concrete array type from the physical type of the values.
        match lhs.dtype().to_physical_type() {
            PH::Boolean => call_binary!(BooleanArray),
            PH::BinaryView => call_binary!(BinaryViewArray),
            PH::Utf8View => call_binary!(Utf8ViewArray),

            // Primitive values.
            PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray<i8>),
            PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray<i16>),
            PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray<i32>),
            PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray<i64>),
            PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray<i128>),
            PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray<u8>),
            PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray<u16>),
            PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray<u32>),
            PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray<u64>),
            PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray<u128>),
            PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray<f16>),
            PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray<f32>),
            PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray<f64>),
            PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray<i256>),
            PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray<days_ms>),
            PH::Primitive(PR::MonthDayNano) => {
                call_binary!(PrimitiveArray<months_days_ns>)
            },

            // Fixed-size lists are only comparable with the `dtype-array` feature.
            #[cfg(feature = "dtype-array")]
            PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray),
            #[cfg(not(feature = "dtype-array"))]
            PH::FixedSizeList => todo!(
                "Comparison of FixedSizeListArray is not supported without dtype-array feature"
            ),

            // Null, binary, string and nested values.
            PH::Null => call_binary!(NullArray),
            PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray),
            PH::Binary => call_binary!(BinaryArray<i32>),
            PH::LargeBinary => call_binary!(BinaryArray<i64>),
            PH::Utf8 => call_binary!(Utf8Array<i32>),
            PH::LargeUtf8 => call_binary!(Utf8Array<i64>),
            PH::List => call_binary!(ListArray<i32>),
            PH::LargeList => call_binary!(ListArray<i64>),
            PH::Struct => call_binary!(StructArray),
            PH::Union => todo!("Comparison of UnionArrays is not yet supported"),
            PH::Map => todo!("Comparison of MapArrays is not yet supported"),

            // Dictionary-encoded values, one arm per key type.
            PH::Dictionary(I::Int8) => call_binary!(DictionaryArray<i8>),
            PH::Dictionary(I::Int16) => call_binary!(DictionaryArray<i16>),
            PH::Dictionary(I::Int32) => call_binary!(DictionaryArray<i32>),
            PH::Dictionary(I::Int64) => call_binary!(DictionaryArray<i64>),
            PH::Dictionary(I::Int128) => call_binary!(DictionaryArray<i128>),
            PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray<u8>),
            PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray<u16>),
            PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray<u32>),
            PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray<u64>),
        }
    }};
}
208
209
impl<O: Offset> TotalEqKernel for ListArray<O> {
210
type Scalar = Box<dyn Array>;
211
212
fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
213
compare!(
214
self,
215
other,
216
TotalEqKernel::tot_eq_missing_kernel,
217
|bm: Bitmap| bm.unset_bits() == 0,
218
false,
219
true
220
)
221
}
222
223
fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
224
compare!(
225
self,
226
other,
227
TotalEqKernel::tot_ne_missing_kernel,
228
|bm: Bitmap| bm.set_bits() > 0,
229
true,
230
false
231
)
232
}
233
234
fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
235
compare_broadcast!(
236
self.values().as_ref(),
237
other.as_ref(),
238
self.offsets(),
239
self.validity(),
240
TotalEqKernel::tot_eq_missing_kernel,
241
|bm: Bitmap| bm.unset_bits() == 0,
242
false,
243
true
244
)
245
}
246
247
fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
248
compare_broadcast!(
249
self.values().as_ref(),
250
other.as_ref(),
251
self.offsets(),
252
self.validity(),
253
TotalEqKernel::tot_ne_missing_kernel,
254
|bm: Bitmap| bm.set_bits() > 0,
255
true,
256
false
257
)
258
}
259
}
260
261