Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-arrow/src/io/ipc/write/common.rs
6940 views
1
use std::borrow::{Borrow, Cow};
2
3
use arrow_format::ipc;
4
use arrow_format::ipc::planus::Builder;
5
use polars_error::{PolarsResult, polars_bail, polars_err};
6
7
use super::super::IpcField;
8
use super::{write, write_dictionary};
9
use crate::array::*;
10
use crate::datatypes::*;
11
use crate::io::ipc::endianness::is_native_little_endian;
12
use crate::io::ipc::read::Dictionaries;
13
use crate::legacy::prelude::LargeListArray;
14
use crate::match_integer_type;
15
use crate::record_batch::RecordBatchT;
16
use crate::types::Index;
17
18
/// Compression codec applied to the body buffers of an IPC message.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Compression {
    /// LZ4 (framed)
    LZ4,
    /// ZSTD
    ZSTD,
}
26
27
/// Options declaring the behaviour of writing to IPC
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct WriteOptions {
    /// Whether the buffers should be compressed and which codec to use.
    /// `None` (the default) writes uncompressed buffers.
    /// Note: to use compression the crate must be compiled with feature `io_ipc_compression`.
    pub compression: Option<Compression>,
}
34
35
/// Find the dictionaries that are new and need to be encoded.
///
/// Recursively walks `array` guided by the matching `field` metadata and, for every
/// dictionary whose id has not yet been recorded in `dictionary_tracker`, pushes
/// `(dictionary_id, boxed array)` onto `dicts_to_encode`.
///
/// # Errors
///
/// Errors when a dictionary field has no id, when the number of IPC child fields
/// does not match the array's children (struct/union), or when the tracker rejects
/// replacing an already-emitted dictionary.
pub fn dictionaries_to_encode(
    field: &IpcField,
    array: &dyn Array,
    dictionary_tracker: &mut DictionaryTracker,
    dicts_to_encode: &mut Vec<(i64, Box<dyn Array>)>,
) -> PolarsResult<()> {
    use PhysicalType::*;
    match array.dtype().to_physical_type() {
        // Flat types cannot contain nested dictionaries — nothing to collect.
        Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null
        | FixedSizeBinary | BinaryView | Utf8View => Ok(()),
        Dictionary(key_type) => match_integer_type!(key_type, |$T| {
            let dict_id = field.dictionary_id
                .ok_or_else(|| polars_err!(InvalidOperation: "Dictionaries must have an associated id"))?;

            // `insert` returns `true` only when this id has not been emitted
            // before (or replacement is allowed), i.e. when encoding is needed.
            if dictionary_tracker.insert(dict_id, array)? {
                dicts_to_encode.push((dict_id, array.to_boxed()));
            }

            // Recurse into the dictionary *values*, which may themselves be nested.
            let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();
            let values = array.values();
            // @Q? Should this not pick fields[0]?
            // NOTE(review): the values are recursed with the same `field` rather
            // than a child field — confirm this is intentional for dictionaries.
            dictionaries_to_encode(field,
                values.as_ref(),
                dictionary_tracker,
                dicts_to_encode,
            )?;

            Ok(())
        }),
        Struct => {
            let array = array.as_any().downcast_ref::<StructArray>().unwrap();
            let fields = field.fields.as_slice();
            // The IPC metadata must describe exactly one child per struct field.
            if array.fields().len() != fields.len() {
                polars_bail!(InvalidOperation:
                    "The number of fields in a struct must equal the number of children in IpcField".to_string(),
                );
            }
            fields
                .iter()
                .zip(array.values().iter())
                .try_for_each(|(field, values)| {
                    dictionaries_to_encode(
                        field,
                        values.as_ref(),
                        dictionary_tracker,
                        dicts_to_encode,
                    )
                })
        },
        List => {
            let values = array
                .as_any()
                .downcast_ref::<ListArray<i32>>()
                .unwrap()
                .values();
            let field = &field.fields[0]; // todo: error instead
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
        LargeList => {
            let values = array
                .as_any()
                .downcast_ref::<ListArray<i64>>()
                .unwrap()
                .values();
            let field = &field.fields[0]; // todo: error instead
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
        FixedSizeList => {
            let values = array
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap()
                .values();
            let field = &field.fields[0]; // todo: error instead
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
        Union => {
            let values = array
                .as_any()
                .downcast_ref::<UnionArray>()
                .unwrap()
                .fields();
            let fields = &field.fields[..]; // todo: error instead
            // Like structs, unions need one IPC child field per variant.
            if values.len() != fields.len() {
                polars_bail!(InvalidOperation:
                    "The number of fields in a union must equal the number of children in IpcField"
                );
            }
            fields
                .iter()
                .zip(values.iter())
                .try_for_each(|(field, values)| {
                    dictionaries_to_encode(
                        field,
                        values.as_ref(),
                        dictionary_tracker,
                        dicts_to_encode,
                    )
                })
        },
        Map => {
            // A map's single child is its entries struct, exposed via `field()`.
            let values = array.as_any().downcast_ref::<MapArray>().unwrap().field();
            let field = &field.fields[0]; // todo: error instead
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
    }
}
143
144
/// Encode a dictionary array with a certain id.
145
///
146
/// # Panics
147
///
148
/// This will panic if the given array is not a [`DictionaryArray`].
149
pub fn encode_dictionary(
150
dict_id: i64,
151
array: &dyn Array,
152
options: &WriteOptions,
153
encoded_dictionaries: &mut Vec<EncodedData>,
154
) -> PolarsResult<()> {
155
let PhysicalType::Dictionary(key_type) = array.dtype().to_physical_type() else {
156
panic!("Given array is not a DictionaryArray")
157
};
158
159
match_integer_type!(key_type, |$T| {
160
let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();
161
encoded_dictionaries.push(dictionary_batch_to_bytes::<$T>(
162
dict_id,
163
array,
164
options,
165
is_native_little_endian(),
166
));
167
});
168
169
Ok(())
170
}
171
172
pub fn encode_new_dictionaries(
173
field: &IpcField,
174
array: &dyn Array,
175
options: &WriteOptions,
176
dictionary_tracker: &mut DictionaryTracker,
177
encoded_dictionaries: &mut Vec<EncodedData>,
178
) -> PolarsResult<()> {
179
let mut dicts_to_encode = Vec::new();
180
dictionaries_to_encode(field, array, dictionary_tracker, &mut dicts_to_encode)?;
181
for (dict_id, dict_array) in dicts_to_encode {
182
encode_dictionary(dict_id, dict_array.as_ref(), options, encoded_dictionaries)?;
183
}
184
Ok(())
185
}
186
187
pub fn encode_chunk(
188
chunk: &RecordBatchT<Box<dyn Array>>,
189
fields: &[IpcField],
190
dictionary_tracker: &mut DictionaryTracker,
191
options: &WriteOptions,
192
) -> PolarsResult<(Vec<EncodedData>, EncodedData)> {
193
let mut encoded_message = EncodedData::default();
194
let encoded_dictionaries = encode_chunk_amortized(
195
chunk,
196
fields,
197
dictionary_tracker,
198
options,
199
&mut encoded_message,
200
)?;
201
Ok((encoded_dictionaries, encoded_message))
202
}
203
204
// Amortizes `EncodedData` allocation.
205
pub fn encode_chunk_amortized(
206
chunk: &RecordBatchT<Box<dyn Array>>,
207
fields: &[IpcField],
208
dictionary_tracker: &mut DictionaryTracker,
209
options: &WriteOptions,
210
encoded_message: &mut EncodedData,
211
) -> PolarsResult<Vec<EncodedData>> {
212
let mut encoded_dictionaries = vec![];
213
214
for (field, array) in fields.iter().zip(chunk.as_ref()) {
215
encode_new_dictionaries(
216
field,
217
array.as_ref(),
218
options,
219
dictionary_tracker,
220
&mut encoded_dictionaries,
221
)?;
222
}
223
encode_record_batch(chunk, options, encoded_message);
224
225
Ok(encoded_dictionaries)
226
}
227
228
fn serialize_compression(
229
compression: Option<Compression>,
230
) -> Option<Box<arrow_format::ipc::BodyCompression>> {
231
if let Some(compression) = compression {
232
let codec = match compression {
233
Compression::LZ4 => arrow_format::ipc::CompressionType::Lz4Frame,
234
Compression::ZSTD => arrow_format::ipc::CompressionType::Zstd,
235
};
236
Some(Box::new(arrow_format::ipc::BodyCompression {
237
codec,
238
method: arrow_format::ipc::BodyCompressionMethod::Buffer,
239
}))
240
} else {
241
None
242
}
243
}
244
245
fn set_variadic_buffer_counts(counts: &mut Vec<i64>, array: &dyn Array) {
246
match array.dtype() {
247
ArrowDataType::Utf8View => {
248
let array = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
249
counts.push(array.data_buffers().len() as i64);
250
},
251
ArrowDataType::BinaryView => {
252
let array = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
253
counts.push(array.data_buffers().len() as i64);
254
},
255
ArrowDataType::Struct(_) => {
256
let array = array.as_any().downcast_ref::<StructArray>().unwrap();
257
for array in array.values() {
258
set_variadic_buffer_counts(counts, array.as_ref())
259
}
260
},
261
ArrowDataType::LargeList(_) => {
262
// Subslicing can change the variadic buffer count, so we have to
263
// slice here as well to stay synchronized.
264
let array = array.as_any().downcast_ref::<LargeListArray>().unwrap();
265
let offsets = array.offsets().buffer();
266
let first = *offsets.first().unwrap();
267
let last = *offsets.last().unwrap();
268
let subslice = array
269
.values()
270
.sliced(first.to_usize(), last.to_usize() - first.to_usize());
271
set_variadic_buffer_counts(counts, &*subslice)
272
},
273
ArrowDataType::FixedSizeList(_, _) => {
274
let array = array.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
275
set_variadic_buffer_counts(counts, array.values().as_ref())
276
},
277
// Don't traverse dictionary values as those are set when the `Dictionary` IPC struct
278
// is read.
279
ArrowDataType::Dictionary(_, _, _) => (),
280
_ => (),
281
}
282
}
283
284
fn gc_bin_view<'a, T: ViewType + ?Sized>(
285
arr: &'a Box<dyn Array>,
286
concrete_arr: &'a BinaryViewArrayGeneric<T>,
287
) -> Cow<'a, Box<dyn Array>> {
288
let bytes_len = concrete_arr.total_bytes_len();
289
let buffer_len = concrete_arr.total_buffer_len();
290
let extra_len = buffer_len.saturating_sub(bytes_len);
291
if extra_len < bytes_len.min(1024) {
292
// We can afford some tiny waste.
293
Cow::Borrowed(arr)
294
} else {
295
// Force GC it.
296
Cow::Owned(concrete_arr.clone().gc().boxed())
297
}
298
}
299
300
pub fn encode_array(
301
array: &Box<dyn Array>,
302
options: &WriteOptions,
303
variadic_buffer_counts: &mut Vec<i64>,
304
buffers: &mut Vec<ipc::Buffer>,
305
arrow_data: &mut Vec<u8>,
306
nodes: &mut Vec<ipc::FieldNode>,
307
offset: &mut i64,
308
) {
309
// We don't want to write all buffers in sliced arrays.
310
let array = match array.dtype() {
311
ArrowDataType::BinaryView => {
312
let concrete_arr = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
313
gc_bin_view(array, concrete_arr)
314
},
315
ArrowDataType::Utf8View => {
316
let concrete_arr = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
317
gc_bin_view(array, concrete_arr)
318
},
319
_ => Cow::Borrowed(array),
320
};
321
let array = array.as_ref().as_ref();
322
323
set_variadic_buffer_counts(variadic_buffer_counts, array);
324
325
write(
326
array,
327
buffers,
328
arrow_data,
329
nodes,
330
offset,
331
is_native_little_endian(),
332
options.compression,
333
)
334
}
335
336
/// Write [`RecordBatchT`] into two sets of bytes, one for the header (ipc::Schema::Message) and the
337
/// other for the batch's data
338
pub fn encode_record_batch(
339
chunk: &RecordBatchT<Box<dyn Array>>,
340
options: &WriteOptions,
341
encoded_message: &mut EncodedData,
342
) {
343
let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
344
let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
345
encoded_message.arrow_data.clear();
346
347
let mut offset = 0;
348
let mut variadic_buffer_counts = vec![];
349
for array in chunk.arrays() {
350
encode_array(
351
array,
352
options,
353
&mut variadic_buffer_counts,
354
&mut buffers,
355
&mut encoded_message.arrow_data,
356
&mut nodes,
357
&mut offset,
358
);
359
}
360
361
commit_encoded_arrays(
362
chunk.len(),
363
options,
364
variadic_buffer_counts,
365
buffers,
366
nodes,
367
encoded_message,
368
);
369
}
370
371
pub fn commit_encoded_arrays(
372
array_len: usize,
373
options: &WriteOptions,
374
variadic_buffer_counts: Vec<i64>,
375
buffers: Vec<ipc::Buffer>,
376
nodes: Vec<ipc::FieldNode>,
377
encoded_message: &mut EncodedData,
378
) {
379
let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
380
None
381
} else {
382
Some(variadic_buffer_counts)
383
};
384
385
let compression = serialize_compression(options.compression);
386
387
let message = arrow_format::ipc::Message {
388
version: arrow_format::ipc::MetadataVersion::V5,
389
header: Some(arrow_format::ipc::MessageHeader::RecordBatch(Box::new(
390
arrow_format::ipc::RecordBatch {
391
length: array_len as i64,
392
nodes: Some(nodes),
393
buffers: Some(buffers),
394
compression,
395
variadic_buffer_counts,
396
},
397
))),
398
body_length: encoded_message.arrow_data.len() as i64,
399
custom_metadata: None,
400
};
401
402
let mut builder = Builder::new();
403
let ipc_message = builder.finish(&message, None);
404
encoded_message.ipc_message = ipc_message.to_vec();
405
}
406
407
/// Write dictionary values into two sets of bytes, one for the header (ipc::Schema::Message) and the
/// other for the data
fn dictionary_batch_to_bytes<K: DictionaryKey>(
    dict_id: i64,
    array: &DictionaryArray<K>,
    options: &WriteOptions,
    is_little_endian: bool,
) -> EncodedData {
    let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
    let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
    let mut arrow_data: Vec<u8> = vec![];
    let mut variadic_buffer_counts = vec![];
    // Only the dictionary *values* are written in this batch, so only they can
    // contribute variadic (view) buffer counts.
    set_variadic_buffer_counts(&mut variadic_buffer_counts, array.values().as_ref());

    // The IPC message expects `None` rather than an empty list here.
    let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
        None
    } else {
        Some(variadic_buffer_counts)
    };

    // Serialize the dictionary into `buffers`/`arrow_data`/`nodes`; `length` is
    // whatever `write_dictionary` reports and is used as the batch length below.
    let length = write_dictionary(
        array,
        &mut buffers,
        &mut arrow_data,
        &mut nodes,
        &mut 0,
        is_little_endian,
        options.compression,
        false,
    );

    let compression = serialize_compression(options.compression);

    let message = arrow_format::ipc::Message {
        version: arrow_format::ipc::MetadataVersion::V5,
        header: Some(arrow_format::ipc::MessageHeader::DictionaryBatch(Box::new(
            arrow_format::ipc::DictionaryBatch {
                id: dict_id,
                data: Some(Box::new(arrow_format::ipc::RecordBatch {
                    length: length as i64,
                    nodes: Some(nodes),
                    buffers: Some(buffers),
                    compression,
                    variadic_buffer_counts,
                })),
                // This writer always emits full dictionaries, never delta updates.
                is_delta: false,
            },
        ))),
        body_length: arrow_data.len() as i64,
        custom_metadata: None,
    };

    // Serialize the flatbuffer header.
    let mut builder = Builder::new();
    let ipc_message = builder.finish(&message, None);

    EncodedData {
        ipc_message: ipc_message.to_vec(),
        arrow_data,
    }
}
467
468
/// Keeps track of dictionaries that have been written, to avoid emitting the same dictionary
/// multiple times. Can optionally error if an update to an existing dictionary is attempted, which
/// isn't allowed in the `FileWriter`.
pub struct DictionaryTracker {
    /// Last values emitted for each dictionary id (see [`DictionaryTracker::insert`]).
    pub dictionaries: Dictionaries,
    /// When `true`, re-emitting an id with different values is an error
    /// (required for the IPC *file* format).
    pub cannot_replace: bool,
}
475
476
impl DictionaryTracker {
    /// Keep track of the dictionary with the given ID and values. Behavior:
    ///
    /// * If this ID has been written already and has the same data, return `Ok(false)` to indicate
    /// that the dictionary was not actually inserted (because it's already been seen).
    /// * If this ID has been written already but with different data, and this tracker is
    /// configured to return an error, return an error.
    /// * If the tracker has not been configured to error on replacement or this dictionary
    /// has never been seen before, return `Ok(true)` to indicate that the dictionary was just
    /// inserted.
    pub fn insert(&mut self, dict_id: i64, array: &dyn Array) -> PolarsResult<bool> {
        // Extract the dictionary's values; callers must pass a dictionary-typed
        // array, anything else is a programming error.
        let values = match array.dtype() {
            ArrowDataType::Dictionary(key_type, _, _) => {
                match_integer_type!(key_type, |$T| {
                    let array = array
                        .as_any()
                        .downcast_ref::<DictionaryArray<$T>>()
                        .unwrap();
                    array.values()
                })
            },
            _ => unreachable!(),
        };

        // If a dictionary with this id was already emitted, check if it was the same.
        if let Some(last) = self.dictionaries.get(&dict_id) {
            if last.as_ref() == values.as_ref() {
                // Same dictionary values => no need to emit it again
                return Ok(false);
            } else if self.cannot_replace {
                polars_bail!(InvalidOperation:
                    "Dictionary replacement detected when writing IPC file format. \
                    Arrow IPC files only support a single dictionary for a given field \
                    across all batches."
                );
            }
        };

        // New (or replaceable) dictionary: remember its values for future comparisons.
        self.dictionaries.insert(dict_id, values.clone());
        Ok(true)
    }
}
518
519
/// Stores the encoded data, which is an ipc::Schema::Message, and optional Arrow data
#[derive(Debug, Default)]
pub struct EncodedData {
    /// An encoded ipc::Schema::Message
    pub ipc_message: Vec<u8>,
    /// Arrow buffers to be written, should be an empty vec for schema messages
    pub arrow_data: Vec<u8>,
}
527
528
/// Compute the number of padding bytes needed to round `len` up to the next
/// 64-byte boundary (0 when `len` is already a multiple of 64).
///
/// Note: despite what an older comment claimed, this pads to 64 bytes, not 8,
/// matching the function name and the IPC body alignment used here.
#[inline]
pub(crate) fn pad_to_64(len: usize) -> usize {
    // Round up to the next multiple of 64, then subtract `len`.
    ((len + 63) & !63) - len
}
533
534
/// An array [`RecordBatchT`] with optional accompanying IPC fields.
#[derive(Debug, Clone, PartialEq)]
pub struct Record<'a> {
    // The batch columns; `Cow` allows constructing from owned or borrowed data.
    columns: Cow<'a, RecordBatchT<Box<dyn Array>>>,
    // Optional IPC field metadata for the columns; `None` when unspecified.
    fields: Option<Cow<'a, [IpcField]>>,
}
540
541
impl Record<'_> {
542
/// Get the IPC fields for this record.
543
pub fn fields(&self) -> Option<&[IpcField]> {
544
self.fields.as_deref()
545
}
546
547
/// Get the Arrow columns in this record.
548
pub fn columns(&self) -> &RecordBatchT<Box<dyn Array>> {
549
self.columns.borrow()
550
}
551
}
552
553
/// Build a [`Record`] that owns its columns and carries no IPC field metadata.
impl From<RecordBatchT<Box<dyn Array>>> for Record<'static> {
    fn from(columns: RecordBatchT<Box<dyn Array>>) -> Self {
        Self {
            columns: Cow::Owned(columns),
            fields: None,
        }
    }
}
561
562
/// Build a [`Record`] from owned columns plus optional IPC field metadata.
impl<'a, F> From<(RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>
where
    F: Into<Cow<'a, [IpcField]>>,
{
    fn from((columns, fields): (RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {
        Self {
            columns: Cow::Owned(columns),
            fields: fields.map(|f| f.into()),
        }
    }
}
573
574
/// Build a [`Record`] that borrows its columns, plus optional IPC field metadata.
impl<'a, F> From<(&'a RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>
where
    F: Into<Cow<'a, [IpcField]>>,
{
    fn from((columns, fields): (&'a RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {
        Self {
            columns: Cow::Borrowed(columns),
            fields: fields.map(|f| f.into()),
        }
    }
}
585
586