Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/comparisons/list.rs
8431 views
1
use arrow::array::{
2
Array, BinaryArray, BinaryViewArray, BooleanArray, DictionaryArray, FixedSizeBinaryArray,
3
ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array, Utf8ViewArray,
4
};
5
use arrow::bitmap::Bitmap;
6
use arrow::legacy::utils::CustomIterTools;
7
use arrow::types::{Offset, days_ms, i256, months_days_ns};
8
use polars_utils::float16::pf16;
9
10
use super::TotalEqKernel;
11
12
/// Row-wise total (in)equality between two `ListArray`s of equal length and dtype.
///
/// Arguments:
/// - `$lhs`, `$rhs`: the two list arrays; each expression is evaluated once.
/// - `$op`: a `TotalEqKernel` method applied to the pair of child-value slices of a row.
/// - `$true_op`: reduces the result of `$op` to the row's `bool`.
/// - `$ineq_len_rv`: the row result when the two sub-lists have different lengths.
/// - `$invalid_rv`: the row result when either side of the row is null.
///
/// The child arrays are downcast to a concrete array type by dispatching on the
/// physical type of the left child dtype. The per-row results are collected with
/// `collect_trusted`.
///
/// # Panics
/// - If `$lhs` and `$rhs` differ in length or dtype.
/// - For child physical types that are not handled: `MonthDayMillis`, `Union`, `Map`,
///   and `FixedSizeList` when the `dtype-array` feature is disabled.
macro_rules! compare {
    (
        $lhs:expr, $rhs:expr,
        $op:path, $true_op:expr,
        $ineq_len_rv:literal, $invalid_rv:literal
    ) => {{
        let lhs = $lhs;
        let rhs = $rhs;

        assert_eq!(lhs.len(), rhs.len());
        assert_eq!(lhs.dtype(), rhs.dtype());

        // Downcasts both child arrays to `$T` and compares the arrays row by row.
        macro_rules! call_binary {
            ($T:ty) => {{
                let lhs_values: &$T = lhs.values().as_any().downcast_ref().unwrap();
                let rhs_values: &$T = rhs.values().as_any().downcast_ref().unwrap();

                // Fetch the validity bitmaps once instead of on every row.
                let lhs_validity = lhs.validity();
                let rhs_validity = rhs.validity();

                (0..lhs.len())
                    .map(|i| {
                        let lval = lhs_validity.is_none_or(|v| v.get(i).unwrap());
                        let rval = rhs_validity.is_none_or(|v| v.get(i).unwrap());

                        if !lval || !rval {
                            return $invalid_rv;
                        }

                        // SAFETY: `i < lhs.len() == rhs.len()` (asserted above), and a
                        // ListArray guarantees `offsets.len_proxy() == len()`.
                        let (lstart, lend) = unsafe { lhs.offsets().start_end_unchecked(i) };
                        let (rstart, rend) = unsafe { rhs.offsets().start_end_unchecked(i) };

                        // Sub-lists of different lengths are decided without running
                        // the element-wise kernel.
                        if lend - lstart != rend - rstart {
                            return $ineq_len_rv;
                        }

                        // Narrow a copy of each child array down to this row's sub-list.
                        let mut lhs_values = lhs_values.clone();
                        lhs_values.slice(lstart, lend - lstart);
                        let mut rhs_values = rhs_values.clone();
                        rhs_values.slice(rstart, rend - rstart);

                        $true_op($op(&lhs_values, &rhs_values))
                    })
                    .collect_trusted()
            }};
        }

        use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR};
        match lhs.values().dtype().to_physical_type() {
            PH::Boolean => call_binary!(BooleanArray),
            PH::BinaryView => call_binary!(BinaryViewArray),
            PH::Utf8View => call_binary!(Utf8ViewArray),
            PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray<i8>),
            PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray<i16>),
            PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray<i32>),
            PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray<i64>),
            PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray<i128>),
            PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray<u8>),
            PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray<u16>),
            PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray<u32>),
            PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray<u64>),
            PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray<u128>),
            PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray<pf16>),
            PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray<f32>),
            PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray<f64>),
            PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray<i256>),
            PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray<days_ms>),
            PH::Primitive(PR::MonthDayNano) => {
                call_binary!(PrimitiveArray<months_days_ns>)
            },
            PH::Primitive(PR::MonthDayMillis) => unimplemented!(),

            #[cfg(feature = "dtype-array")]
            PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray),
            #[cfg(not(feature = "dtype-array"))]
            PH::FixedSizeList => todo!(
                "Comparison of FixedSizeListArray is not supported without dtype-array feature"
            ),

            PH::Null => call_binary!(NullArray),
            PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray),
            PH::Binary => call_binary!(BinaryArray<i32>),
            PH::LargeBinary => call_binary!(BinaryArray<i64>),
            PH::Utf8 => call_binary!(Utf8Array<i32>),
            PH::LargeUtf8 => call_binary!(Utf8Array<i64>),
            PH::List => call_binary!(ListArray<i32>),
            PH::LargeList => call_binary!(ListArray<i64>),
            PH::Struct => call_binary!(StructArray),
            PH::Union => todo!("Comparison of UnionArrays is not yet supported"),
            PH::Map => todo!("Comparison of MapArrays is not yet supported"),
            PH::Dictionary(I::Int8) => call_binary!(DictionaryArray<i8>),
            PH::Dictionary(I::Int16) => call_binary!(DictionaryArray<i16>),
            PH::Dictionary(I::Int32) => call_binary!(DictionaryArray<i32>),
            PH::Dictionary(I::Int64) => call_binary!(DictionaryArray<i64>),
            PH::Dictionary(I::Int128) => call_binary!(DictionaryArray<i128>),
            PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray<u8>),
            PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray<u16>),
            PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray<u32>),
            PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray<u64>),
            PH::Dictionary(I::UInt128) => call_binary!(DictionaryArray<u128>),
        }
    }};
}
113
114
/// Row-wise total (in)equality between every sub-list of a list array and one scalar sub-list.
///
/// Arguments:
/// - `$lhs`: the list array's child values (callers pass `self.values().as_ref()`).
/// - `$rhs`: the scalar sub-list, as an array whose dtype must equal `$lhs`'s dtype.
/// - `$offsets`: the list array's offsets; row `i` covers child values `start..end`.
/// - `$validity`: the list array's optional validity bitmap.
/// - `$op`: a `TotalEqKernel` method applied to a row's child-value slice and the scalar.
/// - `$true_op`: reduces the result of `$op` to the row's `bool`.
/// - `$ineq_len_rv`: the row result when the row's sub-list length differs from the scalar's.
/// - `$invalid_rv`: the row result when the row is null.
///
/// Each macro argument expression is evaluated once.
///
/// # Panics
/// - If the dtypes of `$lhs` and `$rhs` differ.
/// - For child physical types that are not handled: `MonthDayMillis`, `Union`, `Map`,
///   and `FixedSizeList` when the `dtype-array` feature is disabled.
macro_rules! compare_broadcast {
    (
        $lhs:expr, $rhs:expr,
        $offsets:expr, $validity:expr,
        $op:path, $true_op:expr,
        $ineq_len_rv:literal, $invalid_rv:literal
    ) => {{
        let lhs = $lhs;
        let rhs = $rhs;
        let offsets = $offsets;
        let validity = $validity;

        // Downcasts the child values and the scalar to `$T`, then compares each row's
        // sub-list against the scalar.
        macro_rules! call_binary {
            ($T:ty) => {{
                let values: &$T = lhs.as_any().downcast_ref().unwrap();
                let scalar: &$T = rhs.as_any().downcast_ref().unwrap();

                (0..offsets.len_proxy())
                    .map(move |i| {
                        if !validity.is_none_or(|v| v.get(i).unwrap()) {
                            return $invalid_rv;
                        }

                        // SAFETY: `i < offsets.len_proxy()`, the bound of the range
                        // being iterated.
                        let (start, end) = unsafe { offsets.start_end_unchecked(i) };

                        if end - start != scalar.len() {
                            return $ineq_len_rv;
                        }

                        // TODO: avoid cloning the whole child array per row just to slice it.
                        let mut row_values: $T = values.clone();
                        <$T>::slice(&mut row_values, start, end - start);

                        $true_op($op(&row_values, scalar))
                    })
                    .collect_trusted()
            }};
        }

        assert_eq!(lhs.dtype(), rhs.dtype());

        use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR};
        match lhs.dtype().to_physical_type() {
            PH::Boolean => call_binary!(BooleanArray),
            PH::BinaryView => call_binary!(BinaryViewArray),
            PH::Utf8View => call_binary!(Utf8ViewArray),
            PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray<i8>),
            PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray<i16>),
            PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray<i32>),
            PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray<i64>),
            PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray<i128>),
            PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray<u8>),
            PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray<u16>),
            PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray<u32>),
            PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray<u64>),
            PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray<u128>),
            PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray<pf16>),
            PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray<f32>),
            PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray<f64>),
            PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray<i256>),
            PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray<days_ms>),
            PH::Primitive(PR::MonthDayNano) => {
                call_binary!(PrimitiveArray<months_days_ns>)
            },
            PH::Primitive(PR::MonthDayMillis) => unimplemented!(),

            #[cfg(feature = "dtype-array")]
            PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray),
            #[cfg(not(feature = "dtype-array"))]
            PH::FixedSizeList => todo!(
                "Comparison of FixedSizeListArray is not supported without dtype-array feature"
            ),

            PH::Null => call_binary!(NullArray),
            PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray),
            PH::Binary => call_binary!(BinaryArray<i32>),
            PH::LargeBinary => call_binary!(BinaryArray<i64>),
            PH::Utf8 => call_binary!(Utf8Array<i32>),
            PH::LargeUtf8 => call_binary!(Utf8Array<i64>),
            PH::List => call_binary!(ListArray<i32>),
            PH::LargeList => call_binary!(ListArray<i64>),
            PH::Struct => call_binary!(StructArray),
            PH::Union => todo!("Comparison of UnionArrays is not yet supported"),
            PH::Map => todo!("Comparison of MapArrays is not yet supported"),
            PH::Dictionary(I::Int8) => call_binary!(DictionaryArray<i8>),
            PH::Dictionary(I::Int16) => call_binary!(DictionaryArray<i16>),
            PH::Dictionary(I::Int32) => call_binary!(DictionaryArray<i32>),
            PH::Dictionary(I::Int64) => call_binary!(DictionaryArray<i64>),
            PH::Dictionary(I::Int128) => call_binary!(DictionaryArray<i128>),
            PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray<u8>),
            PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray<u16>),
            PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray<u32>),
            PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray<u64>),
            PH::Dictionary(I::UInt128) => call_binary!(DictionaryArray<u128>),
        }
    }};
}
213
214
impl<O: Offset> TotalEqKernel for ListArray<O> {
215
type Scalar = Box<dyn Array>;
216
217
fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
218
compare!(
219
self,
220
other,
221
TotalEqKernel::tot_eq_missing_kernel,
222
|bm: Bitmap| bm.unset_bits() == 0,
223
false,
224
true
225
)
226
}
227
228
fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
229
compare!(
230
self,
231
other,
232
TotalEqKernel::tot_ne_missing_kernel,
233
|bm: Bitmap| bm.set_bits() > 0,
234
true,
235
false
236
)
237
}
238
239
fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
240
compare_broadcast!(
241
self.values().as_ref(),
242
other.as_ref(),
243
self.offsets(),
244
self.validity(),
245
TotalEqKernel::tot_eq_missing_kernel,
246
|bm: Bitmap| bm.unset_bits() == 0,
247
false,
248
true
249
)
250
}
251
252
fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
253
compare_broadcast!(
254
self.values().as_ref(),
255
other.as_ref(),
256
self.offsets(),
257
self.validity(),
258
TotalEqKernel::tot_ne_missing_kernel,
259
|bm: Bitmap| bm.set_bits() > 0,
260
true,
261
false
262
)
263
}
264
}
265
266