Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-compute/src/find_validity_mismatch.rs
6939 views
1
use arrow::array::{Array, FixedSizeListArray, ListArray, StructArray};
2
use arrow::datatypes::ArrowDataType;
3
use arrow::types::Offset;
4
use polars_utils::IdxSize;
5
use polars_utils::itertools::Itertools;
6
7
use crate::cast::CastOptionsImpl;
8
9
/// Find the indices of the values where the validity mismatches.
10
///
11
/// This is done recursively, meaning that a validity mismatch at a deeper level will result as at
12
/// the level above at the corresponding index.
13
///
14
/// This procedure requires that
15
/// - Nulls are propagated recursively
16
/// - Lists to be
17
/// - trimmed to normalized offsets
18
/// - have the same number of child elements below each element (even nulls)
19
pub fn find_validity_mismatch(left: &dyn Array, right: &dyn Array, idxs: &mut Vec<IdxSize>) {
20
assert_eq!(left.len(), right.len());
21
22
// Handle the top-level.
23
//
24
// NOTE: This is done always, even if left and right have different nestings. This is
25
// intentional and needed.
26
let original_idxs_length = idxs.len();
27
match (left.validity(), right.validity()) {
28
(None, None) => {},
29
(Some(l), Some(r)) => {
30
if l != r {
31
let mismatches = arrow::bitmap::xor(l, r);
32
idxs.extend(mismatches.true_idx_iter().map(|i| i as IdxSize));
33
}
34
},
35
(Some(v), _) | (_, Some(v)) => {
36
if v.unset_bits() > 0 {
37
let mismatches = !v;
38
idxs.extend(mismatches.true_idx_iter().map(|i| i as IdxSize));
39
}
40
},
41
}
42
43
let left = left.as_any();
44
let right = right.as_any();
45
46
let pre_nesting_length = idxs.len();
47
// (Struct, Struct)
48
if let (Some(left), Some(right)) = (
49
left.downcast_ref::<StructArray>(),
50
right.downcast_ref::<StructArray>(),
51
) {
52
assert_eq!(left.fields().len(), right.fields().len());
53
for (l, r) in left.values().iter().zip(right.values().iter()) {
54
find_validity_mismatch(l.as_ref(), r.as_ref(), idxs);
55
}
56
}
57
58
// (List, List)
59
if let (Some(left), Some(right)) = (
60
left.downcast_ref::<ListArray<i32>>(),
61
right.downcast_ref::<ListArray<i32>>(),
62
) {
63
find_validity_mismatch_list_list_nested(left, right, idxs);
64
}
65
if let (Some(left), Some(right)) = (
66
left.downcast_ref::<ListArray<i64>>(),
67
right.downcast_ref::<ListArray<i64>>(),
68
) {
69
find_validity_mismatch_list_list_nested(left, right, idxs);
70
}
71
72
// (FixedSizeList, FixedSizeList)
73
if let (Some(left), Some(right)) = (
74
left.downcast_ref::<FixedSizeListArray>(),
75
right.downcast_ref::<FixedSizeListArray>(),
76
) {
77
assert_eq!(left.size(), right.size());
78
find_validity_mismatch_fsl_fsl_nested(
79
left.values().as_ref(),
80
right.values().as_ref(),
81
left.size(),
82
idxs,
83
)
84
}
85
86
// (List, Array) / (Array, List)
87
if let (Some(left), Some(right)) = (
88
left.downcast_ref::<ListArray<i32>>(),
89
right.downcast_ref::<FixedSizeListArray>(),
90
) {
91
find_validity_mismatch_list_fsl_impl(left, right, idxs);
92
}
93
if let (Some(left), Some(right)) = (
94
left.downcast_ref::<ListArray<i64>>(),
95
right.downcast_ref::<FixedSizeListArray>(),
96
) {
97
find_validity_mismatch_list_fsl_impl(left, right, idxs);
98
}
99
if let (Some(right), Some(left)) = (
100
left.downcast_ref::<FixedSizeListArray>(),
101
right.downcast_ref::<ListArray<i32>>(),
102
) {
103
find_validity_mismatch_list_fsl_impl(left, right, idxs);
104
}
105
if let (Some(right), Some(left)) = (
106
left.downcast_ref::<FixedSizeListArray>(),
107
right.downcast_ref::<ListArray<i64>>(),
108
) {
109
find_validity_mismatch_list_fsl_impl(left, right, idxs);
110
}
111
112
if pre_nesting_length == idxs.len() {
113
return;
114
}
115
idxs[original_idxs_length..].sort_unstable();
116
}
117
118
fn find_validity_mismatch_fsl_fsl_nested(
119
left: &dyn Array,
120
right: &dyn Array,
121
size: usize,
122
idxs: &mut Vec<IdxSize>,
123
) {
124
assert_eq!(left.len(), right.len());
125
let start_length = idxs.len();
126
find_validity_mismatch(left, right, idxs);
127
if idxs.len() > start_length {
128
let mut offset = 0;
129
idxs[start_length] /= size as IdxSize;
130
for i in start_length + 1..idxs.len() {
131
idxs[i - offset] = idxs[i] / size as IdxSize;
132
133
if idxs[i - offset] == idxs[i - offset - 1] {
134
offset += 1;
135
}
136
}
137
idxs.truncate(idxs.len() - offset);
138
}
139
}
140
141
fn find_validity_mismatch_list_list_nested<O: Offset>(
142
left: &ListArray<O>,
143
right: &ListArray<O>,
144
idxs: &mut Vec<IdxSize>,
145
) {
146
let mut nested_idxs = Vec::new();
147
find_validity_mismatch(
148
left.values().as_ref(),
149
right.values().as_ref(),
150
&mut nested_idxs,
151
);
152
153
if nested_idxs.is_empty() {
154
return;
155
}
156
157
assert_eq!(left.offsets().first().to_usize(), 0);
158
assert_eq!(left.offsets().range().to_usize(), left.values().len());
159
160
// @TODO: Optimize. This is only used on the error path so it is find, right?
161
let mut j = 0;
162
for (i, (start, length)) in left.offsets().offset_and_length_iter().enumerate_idx() {
163
if j < nested_idxs.len() && (nested_idxs[j] as usize) < start + length {
164
idxs.push(i);
165
j += 1;
166
167
// Loop over remaining items in same element.
168
while j < nested_idxs.len() && (nested_idxs[j] as usize) < start + length {
169
j += 1;
170
}
171
}
172
173
if j == nested_idxs.len() {
174
break;
175
}
176
}
177
}
178
179
fn find_validity_mismatch_list_fsl_impl<O: Offset>(
180
left: &ListArray<O>,
181
right: &FixedSizeListArray,
182
idxs: &mut Vec<IdxSize>,
183
) {
184
if left.validity().is_none() && right.validity().is_none() {
185
find_validity_mismatch_fsl_fsl_nested(
186
left.values().as_ref(),
187
right.values().as_ref(),
188
right.size(),
189
idxs,
190
);
191
return;
192
}
193
194
let (ArrowDataType::List(f) | ArrowDataType::LargeList(f)) = left.dtype() else {
195
unreachable!();
196
};
197
let left = crate::cast::cast_list_to_fixed_size_list(
198
left,
199
f,
200
right.size(),
201
CastOptionsImpl::default(),
202
)
203
.unwrap();
204
find_validity_mismatch_fsl_fsl_nested(
205
left.values().as_ref(),
206
right.values().as_ref(),
207
right.size(),
208
idxs,
209
)
210
}
211
212