GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-arrow/src/io/ipc/read/common.rs
use std::collections::VecDeque;
use std::io::{Read, Seek};
use std::sync::Arc;

use polars_error::{PolarsResult, polars_bail, polars_err};
use polars_utils::aliases::PlHashMap;
use polars_utils::bool::UnsafeBool;
use polars_utils::pl_str::PlSmallStr;

use super::Dictionaries;
use super::deserialize::{read, skip};
use crate::array::*;
use crate::datatypes::{ArrowDataType, ArrowSchema, Field};
use crate::io::ipc::read::OutOfSpecKind;
use crate::io::ipc::{IpcField, IpcSchema};
use crate::record_batch::RecordBatchT;

#[derive(Debug, Eq, PartialEq, Hash)]
enum ProjectionResult<A> {
    Selected(A),
    NotSelected(A),
}

/// An iterator adapter that wraps an iterator and tags each item as `Selected` or
/// `NotSelected`, depending on whether its index appears in `projection`.
/// # Panics
/// The iterator panics iff the `projection` is not strictly increasing.
struct ProjectionIter<'a, A, I: Iterator<Item = A>> {
    projection: &'a [usize],
    iter: I,
    current_count: usize,
    current_projection: usize,
}

impl<'a, A, I: Iterator<Item = A>> ProjectionIter<'a, A, I> {
    /// # Panics
    /// iff `projection` is empty
    pub fn new(projection: &'a [usize], iter: I) -> Self {
        Self {
            projection: &projection[1..],
            iter,
            current_count: 0,
            current_projection: projection[0],
        }
    }
}

impl<A, I: Iterator<Item = A>> Iterator for ProjectionIter<'_, A, I> {
    type Item = ProjectionResult<A>;

    fn next(&mut self) -> Option<Self::Item> {
        if let Some(item) = self.iter.next() {
            let result = if self.current_count == self.current_projection {
                if !self.projection.is_empty() {
                    assert!(self.projection[0] > self.current_projection);
                    self.current_projection = self.projection[0];
                    self.projection = &self.projection[1..];
                } else {
                    // `projection` is exhausted; index 0 has already been passed,
                    // so no further item will be selected
                    self.current_projection = 0
                };
                Some(ProjectionResult::Selected(item))
            } else {
                Some(ProjectionResult::NotSelected(item))
            };
            self.current_count += 1;
            result
        } else {
            None
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}

/// Returns a [`RecordBatchT`] from a reader.
/// # Panics
/// Panics iff the projection is not strictly increasing (e.g. neither `[1, 0]` nor `[0, 1, 1]` is valid).
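/// The projection must therefore be sorted and de-duplicated up front; [`prepare_projection`]
/// below builds such a projection, and [`apply_projection`] restores the caller's requested
/// column order afterwards.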
#[allow(clippy::too_many_arguments)]
pub fn read_record_batch<R: Read + Seek>(
    batch: arrow_format::ipc::RecordBatchRef,
    fields: &ArrowSchema,
    ipc_schema: &IpcSchema,
    projection: Option<&[usize]>,
    limit: Option<usize>,
    dictionaries: &Dictionaries,
    version: arrow_format::ipc::MetadataVersion,
    reader: &mut R,
    block_offset: u64,
    scratch: &mut Vec<u8>,
    checked: UnsafeBool,
) -> PolarsResult<RecordBatchT<Box<dyn Array>>> {
    assert_eq!(fields.len(), ipc_schema.fields.len());
    let buffers = batch
        .buffers()
        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferBuffers(err)))?
        .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageBuffers))?;
    let mut variadic_buffer_counts = batch
        .variadic_buffer_counts()
        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferRecordBatches(err)))?
        .map(|v| v.iter().map(|v| v as usize).collect::<VecDeque<usize>>())
        .unwrap_or_else(VecDeque::new);
    let mut buffers: VecDeque<arrow_format::ipc::BufferRef> = buffers.iter().collect();

    let field_nodes = batch
        .nodes()
        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferNodes(err)))?
        .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageNodes))?;
    let mut field_nodes = field_nodes.iter().collect::<VecDeque<_>>();

    let columns = if let Some(projection) = projection {
        let projection = ProjectionIter::new(
            projection,
            fields.iter_values().zip(ipc_schema.fields.iter()),
        );

        projection
            .map(|maybe_field| match maybe_field {
                ProjectionResult::Selected((field, ipc_field)) => Ok(Some(read(
                    &mut field_nodes,
                    &mut variadic_buffer_counts,
                    field,
                    ipc_field,
                    &mut buffers,
                    reader,
                    dictionaries,
                    block_offset,
                    ipc_schema.is_little_endian,
                    batch.compression().map_err(|err| {
                        polars_err!(oos = OutOfSpecKind::InvalidFlatbufferCompression(err))
                    })?,
                    limit,
                    version,
                    scratch,
                    checked,
                )?)),
                ProjectionResult::NotSelected((field, _)) => {
                    skip(
                        &mut field_nodes,
                        &field.dtype,
                        &mut buffers,
                        &mut variadic_buffer_counts,
                    )?;
                    Ok(None)
                },
            })
            .filter_map(|x| x.transpose())
            .collect::<PolarsResult<Vec<_>>>()?
    } else {
        fields
            .iter_values()
            .zip(ipc_schema.fields.iter())
            .map(|(field, ipc_field)| {
                read(
                    &mut field_nodes,
                    &mut variadic_buffer_counts,
                    field,
                    ipc_field,
                    &mut buffers,
                    reader,
                    dictionaries,
                    block_offset,
                    ipc_schema.is_little_endian,
                    batch.compression().map_err(|err| {
                        polars_err!(oos = OutOfSpecKind::InvalidFlatbufferCompression(err))
                    })?,
                    limit,
                    version,
                    scratch,
                    checked,
                )
            })
            .collect::<PolarsResult<Vec<_>>>()?
    };

    let length = batch
        .length()
        .map_err(|_| polars_err!(oos = OutOfSpecKind::MissingData))?
        .try_into()
        .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
    let length = limit.map(|limit| limit.min(length)).unwrap_or(length);

    let mut schema: ArrowSchema = fields.iter_values().cloned().collect();
    if let Some(projection) = projection {
        schema = schema.try_project_indices(projection).unwrap();
    }
    RecordBatchT::try_new(length, Arc::new(schema), columns)
}

fn find_first_dict_field_d<'a>(
    id: i64,
    dtype: &'a ArrowDataType,
    ipc_field: &'a IpcField,
) -> Option<(&'a Field, &'a IpcField)> {
    use ArrowDataType::*;
    match dtype.to_storage() {
        Dictionary(_, inner, _) => find_first_dict_field_d(id, inner.as_ref(), ipc_field),
        List(field) | LargeList(field) | FixedSizeList(field, ..) | Map(field, ..) => {
            find_first_dict_field(id, field.as_ref(), &ipc_field.fields[0])
        },
        Struct(fields) => {
            for (field, ipc_field) in fields.iter().zip(ipc_field.fields.iter()) {
                if let Some(f) = find_first_dict_field(id, field, ipc_field) {
                    return Some(f);
                }
            }
            None
        },
        Union(u) => {
            for (field, ipc_field) in u.fields.iter().zip(ipc_field.fields.iter()) {
                if let Some(f) = find_first_dict_field(id, field, ipc_field) {
                    return Some(f);
                }
            }
            None
        },
        _ => None,
    }
}

fn find_first_dict_field<'a>(
    id: i64,
    field: &'a Field,
    ipc_field: &'a IpcField,
) -> Option<(&'a Field, &'a IpcField)> {
    if let Some(field_id) = ipc_field.dictionary_id {
        if id == field_id {
            return Some((field, ipc_field));
        }
    }
    find_first_dict_field_d(id, &field.dtype, ipc_field)
}

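/// Returns the first (`Field`, `IpcField`) pair whose IPC field has `dictionary_id == id`,
/// searching nested dtypes depth-first.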
pub(crate) fn first_dict_field<'a>(
    id: i64,
    fields: &'a ArrowSchema,
    ipc_fields: &'a [IpcField],
) -> PolarsResult<(&'a Field, &'a IpcField)> {
    assert_eq!(fields.len(), ipc_fields.len());
    for (field, ipc_field) in fields.iter_values().zip(ipc_fields.iter()) {
        if let Some(field) = find_first_dict_field(id, field, ipc_field) {
            return Ok(field);
        }
    }
    Err(polars_err!(
        oos = OutOfSpecKind::InvalidId { requested_id: id }
    ))
}

/// Reads a dictionary from the reader,
/// updating `dictionaries` with the resulting dictionary
#[allow(clippy::too_many_arguments)]
pub fn read_dictionary<R: Read + Seek>(
    batch: arrow_format::ipc::DictionaryBatchRef,
    fields: &ArrowSchema,
    ipc_schema: &IpcSchema,
    dictionaries: &mut Dictionaries,
    reader: &mut R,
    block_offset: u64,
    scratch: &mut Vec<u8>,
    checked: UnsafeBool,
) -> PolarsResult<()> {
    if batch
        .is_delta()
        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferIsDelta(err)))?
    {
        polars_bail!(ComputeError: "delta dictionary batches not supported")
    }

    let id = batch
        .id()
        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferId(err)))?;
    let (first_field, first_ipc_field) = first_dict_field(id, fields, &ipc_schema.fields)?;

    let batch = batch
        .data()
        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferData(err)))?
        .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingData))?;

    let value_type =
        if let ArrowDataType::Dictionary(_, value_type, _) = first_field.dtype.to_storage() {
            value_type.as_ref()
        } else {
            polars_bail!(oos = OutOfSpecKind::InvalidIdDataType { requested_id: id })
        };

    // Make a fake schema for the dictionary batch.
    let fields = std::iter::once((
        PlSmallStr::EMPTY,
        Field::new(PlSmallStr::EMPTY, value_type.clone(), false),
    ))
    .collect();
    let ipc_schema = IpcSchema {
        fields: vec![first_ipc_field.clone()],
        is_little_endian: ipc_schema.is_little_endian,
    };
    let chunk = read_record_batch(
        batch,
        &fields,
        &ipc_schema,
        None,
        None, // we must read the whole dictionary
        dictionaries,
        arrow_format::ipc::MetadataVersion::V5,
        reader,
        block_offset,
        scratch,
        checked,
    )?;

    dictionaries.insert(id, chunk.into_arrays().pop().unwrap());

    Ok(())
}

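/// The result of [`prepare_projection`]: the projection sorted for reading (`columns`),
/// a map from each column's position in the sorted read to its requested output
/// position (`map`), and the projected schema in the requested order (`schema`).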
#[derive(Clone)]
pub struct ProjectionInfo {
    pub columns: Vec<usize>,
    pub map: PlHashMap<usize, usize>,
    pub schema: ArrowSchema,
}

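/// Sorts `projection` for reading and records how to restore the caller's order.
///
/// For example, `projection = [2, 0]` yields `columns = [0, 2]`,
/// `map = {0 => 1, 1 => 0}` (sorted read position => requested position), and a schema
/// whose fields are in the requested order (column 2 first, then column 0).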
pub fn prepare_projection(schema: &ArrowSchema, mut projection: Vec<usize>) -> ProjectionInfo {
    let schema = projection
        .iter()
        .map(|x| {
            let (k, v) = schema.get_at_index(*x).unwrap();
            (k.clone(), v.clone())
        })
        .collect();

    // todo: find way to do this more efficiently
    let mut indices = (0..projection.len()).collect::<Vec<_>>();
    indices.sort_unstable_by_key(|&i| &projection[i]);
    let map = indices.iter().copied().enumerate().fold(
        PlHashMap::default(),
        |mut acc, (index, new_index)| {
            acc.insert(index, new_index);
            acc
        },
    );
    projection.sort_unstable();

    // check unique
    if !projection.is_empty() {
        let mut previous = projection[0];

        for &i in &projection[1..] {
            assert!(
                previous < i,
                "The projection on IPC must not contain duplicates"
            );
            previous = i;
        }
    }

    ProjectionInfo {
        columns: projection,
        map,
        schema,
    }
}

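/// Re-orders the columns of a chunk that was read with a sorted projection back to the
/// caller's requested order, using the `map` produced by [`prepare_projection`].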
pub fn apply_projection(
    chunk: RecordBatchT<Box<dyn Array>>,
    map: &PlHashMap<usize, usize>,
) -> RecordBatchT<Box<dyn Array>> {
    let length = chunk.len();

    // re-order according to projection
    let (schema, arrays) = chunk.into_schema_and_arrays();
    let mut new_schema = schema.as_ref().clone();
    let mut new_arrays = arrays.clone();

    map.iter().for_each(|(old, new)| {
        let (old_name, old_field) = schema.get_at_index(*old).unwrap();
        let (new_name, new_field) = new_schema.get_at_index_mut(*new).unwrap();

        *new_name = old_name.clone();
        *new_field = old_field.clone();

        new_arrays[*new] = arrays[*old].clone();
    });

    RecordBatchT::new(length, Arc::new(new_schema), new_arrays)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn project_iter() {
        let iter = 1..6;
        let iter = ProjectionIter::new(&[0, 2, 4], iter);
        let result: Vec<_> = iter.collect();
        use ProjectionResult::*;
        assert_eq!(
            result,
            vec![
                Selected(1),
                NotSelected(2),
                Selected(3),
                NotSelected(4),
                Selected(5)
            ]
        )
    }
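
    // A hedged example, not part of the upstream test suite: it documents how the
    // sorted `columns`, the reorder `map`, and the projected `schema` returned by
    // `prepare_projection` relate to the caller's projection. The schema construction
    // mirrors the one used in `read_dictionary` above; `PlSmallStr::from_static` is
    // assumed to be available.
    #[test]
    fn prepare_projection_example() {
        let schema: ArrowSchema = [
            (
                PlSmallStr::from_static("a"),
                Field::new(PlSmallStr::from_static("a"), ArrowDataType::Int32, true),
            ),
            (
                PlSmallStr::from_static("b"),
                Field::new(PlSmallStr::from_static("b"), ArrowDataType::Int64, true),
            ),
            (
                PlSmallStr::from_static("c"),
                Field::new(PlSmallStr::from_static("c"), ArrowDataType::Utf8, true),
            ),
        ]
        .into_iter()
        .collect();

        // The caller asks for column 2 ("c") followed by column 0 ("a").
        let info = prepare_projection(&schema, vec![2, 0]);

        // Reading happens in sorted order ...
        assert_eq!(info.columns, vec![0, 2]);
        // ... while the projected schema keeps the requested order.
        assert_eq!(info.schema.get_at_index(0).unwrap().1.dtype, ArrowDataType::Utf8);
        assert_eq!(info.schema.get_at_index(1).unwrap().1.dtype, ArrowDataType::Int32);
        // `map` sends each column of the sorted read to its requested position:
        // sorted position 0 ("a") -> output 1, sorted position 1 ("c") -> output 0.
        assert_eq!(info.map.get(&0), Some(&1));
        assert_eq!(info.map.get(&1), Some(&0));
    }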
}