Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/comparisons/array.rs
8398 views
1
use arrow::array::{
2
Array, BinaryArray, BinaryViewArray, BooleanArray, DictionaryArray, FixedSizeBinaryArray,
3
FixedSizeListArray, ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array,
4
Utf8ViewArray,
5
};
6
use arrow::bitmap::Bitmap;
7
use arrow::bitmap::utils::count_zeros;
8
use arrow::datatypes::ArrowDataType;
9
use arrow::legacy::utils::CustomIterTools;
10
use arrow::types::{days_ms, i256, months_days_ns};
11
use polars_utils::float16::pf16;
12
13
use super::TotalEqKernel;
14
use crate::comparisons::dyn_array::{array_tot_eq_missing_kernel, array_tot_ne_missing_kernel};
15
16
/// Condenses a bitmap of n * width elements into one with n elements.
17
///
18
/// For each block of width bits a zero count is done. The block of bits is then
19
/// replaced with a single bit: the result of true_zero_count(zero_count).
20
fn agg_array_bitmap<F>(bm: Bitmap, width: usize, true_zero_count: F) -> Bitmap
21
where
22
F: Fn(usize) -> bool,
23
{
24
if bm.len() == 1 {
25
bm
26
} else {
27
assert!(width > 0 && bm.len().is_multiple_of(width));
28
29
let (slice, offset, _len) = bm.as_slice();
30
(0..bm.len() / width)
31
.map(|i| true_zero_count(count_zeros(slice, offset + i * width, width)))
32
.collect()
33
}
34
}
35
36
impl TotalEqKernel for FixedSizeListArray {
37
type Scalar = Box<dyn Array>;
38
39
fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
40
// Nested comparison always done with eq_missing, propagating doesn't
41
// make any sense.
42
43
assert_eq!(self.len(), other.len());
44
let ArrowDataType::FixedSizeList(self_type, self_width) = self.dtype().to_storage() else {
45
panic!("array comparison called with non-array type");
46
};
47
let ArrowDataType::FixedSizeList(other_type, other_width) = other.dtype().to_storage()
48
else {
49
panic!("array comparison called with non-array type");
50
};
51
assert_eq!(self_type.dtype(), other_type.dtype());
52
53
if self_width != other_width {
54
return Bitmap::new_with_value(false, self.len());
55
}
56
57
if *self_width == 0 {
58
return Bitmap::new_with_value(true, self.len());
59
}
60
61
// @TODO: It is probably worth it to dispatch to a special kernel for when there are
62
// several nested arrays because that can be rather slow with this code.
63
let inner = array_tot_eq_missing_kernel(self.values().as_ref(), other.values().as_ref());
64
65
agg_array_bitmap(inner, self.size(), |zeroes| zeroes == 0)
66
}
67
68
fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
69
assert_eq!(self.len(), other.len());
70
let ArrowDataType::FixedSizeList(self_type, self_width) = self.dtype().to_storage() else {
71
panic!("array comparison called with non-array type");
72
};
73
let ArrowDataType::FixedSizeList(other_type, other_width) = other.dtype().to_storage()
74
else {
75
panic!("array comparison called with non-array type");
76
};
77
assert_eq!(self_type.dtype(), other_type.dtype());
78
79
if self_width != other_width {
80
return Bitmap::new_with_value(true, self.len());
81
}
82
83
if *self_width == 0 {
84
return Bitmap::new_with_value(false, self.len());
85
}
86
87
// @TODO: It is probably worth it to dispatch to a special kernel for when there are
88
// several nested arrays because that can be rather slow with this code.
89
let inner = array_tot_ne_missing_kernel(self.values().as_ref(), other.values().as_ref());
90
91
agg_array_bitmap(inner, self.size(), |zeroes| zeroes < self.size())
92
}
93
94
fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
95
let ArrowDataType::FixedSizeList(self_type, width) = self.dtype().to_storage() else {
96
panic!("array comparison called with non-array type");
97
};
98
assert_eq!(self_type.dtype(), other.dtype().to_storage());
99
100
let width = *width;
101
102
if width != other.len() {
103
return Bitmap::new_with_value(false, self.len());
104
}
105
106
if width == 0 {
107
return Bitmap::new_with_value(true, self.len());
108
}
109
110
// @TODO: It is probably worth it to dispatch to a special kernel for when there are
111
// several nested arrays because that can be rather slow with this code.
112
array_fsl_tot_eq_missing_kernel(self.values().as_ref(), other.as_ref(), self.len(), width)
113
}
114
115
fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
116
let ArrowDataType::FixedSizeList(self_type, width) = self.dtype().to_storage() else {
117
panic!("array comparison called with non-array type");
118
};
119
assert_eq!(self_type.dtype(), other.dtype().to_storage());
120
121
let width = *width;
122
123
if width != other.len() {
124
return Bitmap::new_with_value(true, self.len());
125
}
126
127
if width == 0 {
128
return Bitmap::new_with_value(false, self.len());
129
}
130
131
// @TODO: It is probably worth it to dispatch to a special kernel for when there are
132
// several nested arrays because that can be rather slow with this code.
133
array_fsl_tot_ne_missing_kernel(self.values().as_ref(), other.as_ref(), self.len(), width)
134
}
135
}
136
137
/// Compares every `$width`-sized window of `$lhs` against the whole of `$rhs`
/// and collects one boolean per window.
///
/// * `$lhs`, `$rhs` - `&dyn Array`-like expressions with identical dtypes; both
///   are downcast to the concrete array type selected by the physical type.
/// * `$length` - number of windows (output bits) to produce.
/// * `$width` - number of elements per window.
/// * `$op` - path of the kernel called as `$op(&window, rhs)`.
/// * `$true_op` - callable turning the kernel result into the output boolean.
///
/// Panics if the dtypes differ or if a downcast fails; several physical types
/// are not supported and hit `todo!()` / `unimplemented!()`.
macro_rules! compare {
    ($lhs:expr, $rhs:expr, $length:expr, $width:expr, $op:path, $true_op:expr) => {{
        let lhs = $lhs;
        let rhs = $rhs;

        // Runs the comparison for one concrete array type.
        macro_rules! compare_windows {
            ($T:ty) => {{
                let flat: &$T = $lhs.as_any().downcast_ref().unwrap();
                let needle: &$T = $rhs.as_any().downcast_ref().unwrap();

                (0..$length)
                    .map(move |row| {
                        // @TODO: I feel like there is a better way to do this.
                        // Clone the full array, then narrow the clone to this row's window.
                        let start = row * $width;
                        let mut window: $T = flat.clone();
                        <$T>::slice(&mut window, start, $width);

                        $true_op($op(&window, needle))
                    })
                    .collect_trusted()
            }};
        }

        assert_eq!(lhs.dtype(), rhs.dtype());

        use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR};
        match lhs.dtype().to_physical_type() {
            PH::Boolean => compare_windows!(BooleanArray),
            PH::BinaryView => compare_windows!(BinaryViewArray),
            PH::Utf8View => compare_windows!(Utf8ViewArray),
            PH::Primitive(PR::Int8) => compare_windows!(PrimitiveArray<i8>),
            PH::Primitive(PR::Int16) => compare_windows!(PrimitiveArray<i16>),
            PH::Primitive(PR::Int32) => compare_windows!(PrimitiveArray<i32>),
            PH::Primitive(PR::Int64) => compare_windows!(PrimitiveArray<i64>),
            PH::Primitive(PR::Int128) => compare_windows!(PrimitiveArray<i128>),
            PH::Primitive(PR::UInt8) => compare_windows!(PrimitiveArray<u8>),
            PH::Primitive(PR::UInt16) => compare_windows!(PrimitiveArray<u16>),
            PH::Primitive(PR::UInt32) => compare_windows!(PrimitiveArray<u32>),
            PH::Primitive(PR::UInt64) => compare_windows!(PrimitiveArray<u64>),
            PH::Primitive(PR::UInt128) => compare_windows!(PrimitiveArray<u128>),
            PH::Primitive(PR::Float16) => compare_windows!(PrimitiveArray<pf16>),
            PH::Primitive(PR::Float32) => compare_windows!(PrimitiveArray<f32>),
            PH::Primitive(PR::Float64) => compare_windows!(PrimitiveArray<f64>),
            PH::Primitive(PR::Int256) => compare_windows!(PrimitiveArray<i256>),
            PH::Primitive(PR::DaysMs) => compare_windows!(PrimitiveArray<days_ms>),
            PH::Primitive(PR::MonthDayNano) => compare_windows!(PrimitiveArray<months_days_ns>),
            PH::Primitive(PR::MonthDayMillis) => unimplemented!(),

            #[cfg(feature = "dtype-array")]
            PH::FixedSizeList => compare_windows!(arrow::array::FixedSizeListArray),
            #[cfg(not(feature = "dtype-array"))]
            PH::FixedSizeList => todo!(
                "Comparison of FixedSizeListArray is not supported without dtype-array feature"
            ),

            PH::Null => compare_windows!(NullArray),
            PH::FixedSizeBinary => compare_windows!(FixedSizeBinaryArray),
            PH::Binary => compare_windows!(BinaryArray<i32>),
            PH::LargeBinary => compare_windows!(BinaryArray<i64>),
            PH::Utf8 => compare_windows!(Utf8Array<i32>),
            PH::LargeUtf8 => compare_windows!(Utf8Array<i64>),
            PH::List => compare_windows!(ListArray<i32>),
            PH::LargeList => compare_windows!(ListArray<i64>),
            PH::Struct => compare_windows!(StructArray),
            PH::Union => todo!("Comparison of UnionArrays is not yet supported"),
            PH::Map => todo!("Comparison of MapArrays is not yet supported"),
            PH::Dictionary(I::Int8) => compare_windows!(DictionaryArray<i8>),
            PH::Dictionary(I::Int16) => compare_windows!(DictionaryArray<i16>),
            PH::Dictionary(I::Int32) => compare_windows!(DictionaryArray<i32>),
            PH::Dictionary(I::Int64) => compare_windows!(DictionaryArray<i64>),
            PH::Dictionary(I::Int128) => compare_windows!(DictionaryArray<i128>),
            PH::Dictionary(I::UInt8) => compare_windows!(DictionaryArray<u8>),
            PH::Dictionary(I::UInt16) => compare_windows!(DictionaryArray<u16>),
            PH::Dictionary(I::UInt32) => compare_windows!(DictionaryArray<u32>),
            PH::Dictionary(I::UInt64) => compare_windows!(DictionaryArray<u64>),
            PH::Dictionary(I::UInt128) => compare_windows!(DictionaryArray<u128>),
        }
    }};
}
217
218
fn array_fsl_tot_eq_missing_kernel(
219
values: &dyn Array,
220
scalar: &dyn Array,
221
length: usize,
222
width: usize,
223
) -> Bitmap {
224
// @NOTE: Zero-Width Array are handled before
225
debug_assert_eq!(values.len(), length * width);
226
debug_assert_eq!(scalar.len(), width);
227
228
compare!(
229
values,
230
scalar,
231
length,
232
width,
233
TotalEqKernel::tot_eq_missing_kernel,
234
|bm: Bitmap| bm.unset_bits() == 0
235
)
236
}
237
238
fn array_fsl_tot_ne_missing_kernel(
239
values: &dyn Array,
240
scalar: &dyn Array,
241
length: usize,
242
width: usize,
243
) -> Bitmap {
244
// @NOTE: Zero-Width Array are handled before
245
debug_assert_eq!(values.len(), length * width);
246
debug_assert_eq!(scalar.len(), width);
247
248
compare!(
249
values,
250
scalar,
251
length,
252
width,
253
TotalEqKernel::tot_ne_missing_kernel,
254
|bm: Bitmap| bm.set_bits() > 0
255
)
256
}
257
258