GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-arrow/src/io/ipc/write/common.rs
use std::borrow::{Borrow, Cow};

use arrow_format::ipc;
use arrow_format::ipc::KeyValue;
use arrow_format::ipc::planus::Builder;
use bytes::Bytes;
use polars_error::{PolarsResult, polars_bail, polars_err};
use polars_utils::compression::ZstdLevel;

use super::super::IpcField;
use super::write;
use crate::array::*;
use crate::datatypes::*;
use crate::io::ipc::endianness::is_native_little_endian;
use crate::io::ipc::read::Dictionaries;
use crate::legacy::prelude::LargeListArray;
use crate::match_integer_type;
use crate::record_batch::RecordBatchT;
use crate::types::Index;

/// Compression codec
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Compression {
    /// LZ4 (framed)
    LZ4,
    /// ZSTD
    ZSTD(ZstdLevel),
}

/// Options declaring the behaviour of writing to IPC
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct WriteOptions {
    /// Whether the buffers should be compressed and which codec to use.
    /// Note: to use compression the crate must be compiled with feature `io_ipc_compression`.
    pub compression: Option<Compression>,
}

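// Illustrative sketch (not part of the original file): requesting ZSTD-compressed
// buffers when writing. The `ZstdLevel::try_new` constructor shown is an assumption
// about `polars_utils::compression`; leaving `compression: None` (the `Default`)
// writes the buffers uncompressed, and compression requires the
// `io_ipc_compression` feature.
//
//     let options = WriteOptions {
//         compression: Some(Compression::ZSTD(ZstdLevel::try_new(3)?)),
//     };
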
/// Find the dictionaries that are new and need to be encoded.
pub fn dictionaries_to_encode(
    field: &IpcField,
    array: &dyn Array,
    dictionary_tracker: &mut DictionaryTracker,
    dicts_to_encode: &mut Vec<(i64, Box<dyn Array>)>,
) -> PolarsResult<()> {
    use PhysicalType::*;
    match array.dtype().to_physical_type() {
        Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null
        | FixedSizeBinary | BinaryView | Utf8View => Ok(()),
        Dictionary(key_type) => match_integer_type!(key_type, |$T| {
            let dict_id = field.dictionary_id
                .ok_or_else(|| polars_err!(InvalidOperation: "Dictionaries must have an associated id"))?;

            if dictionary_tracker.insert(dict_id, array)? {
                dicts_to_encode.push((dict_id, array.to_boxed()));
            }

            let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();
            let values = array.values();
            // @Q? Should this not pick fields[0]?
            dictionaries_to_encode(
                field,
                values.as_ref(),
                dictionary_tracker,
                dicts_to_encode,
            )?;

            Ok(())
        }),
        Struct => {
            let array = array.as_any().downcast_ref::<StructArray>().unwrap();
            let fields = field.fields.as_slice();
            if array.fields().len() != fields.len() {
                polars_bail!(InvalidOperation: "The number of fields in a struct must equal the number of children in IpcField");
            }
            fields
                .iter()
                .zip(array.values().iter())
                .try_for_each(|(field, values)| {
                    dictionaries_to_encode(
                        field,
                        values.as_ref(),
                        dictionary_tracker,
                        dicts_to_encode,
                    )
                })
        },
        List => {
            let values = array
                .as_any()
                .downcast_ref::<ListArray<i32>>()
                .unwrap()
                .values();
            let field = field.fields.first().ok_or_else(|| polars_err!(ComputeError: "Invalid IPC field structure: expected nested field but fields vector is empty"))?;
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
        LargeList => {
            let values = array
                .as_any()
                .downcast_ref::<ListArray<i64>>()
                .unwrap()
                .values();
            let field = field.fields.first().ok_or_else(|| polars_err!(ComputeError: "Invalid IPC field structure: expected nested field but fields vector is empty"))?;
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
        FixedSizeList => {
            let values = array
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap()
                .values();
            let field = field.fields.first().ok_or_else(|| polars_err!(ComputeError: "Invalid IPC field structure: expected nested field but fields vector is empty"))?;
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
        Union => {
            let values = array
                .as_any()
                .downcast_ref::<UnionArray>()
                .unwrap()
                .fields();
            let fields = field.fields.as_slice();
            if values.len() != fields.len() {
                polars_bail!(InvalidOperation:
                    "The number of fields in a union must equal the number of children in IpcField"
                );
            }
            fields
                .iter()
                .zip(values.iter())
                .try_for_each(|(field, values)| {
                    dictionaries_to_encode(
                        field,
                        values.as_ref(),
                        dictionary_tracker,
                        dicts_to_encode,
                    )
                })
        },
        Map => {
            let values = array.as_any().downcast_ref::<MapArray>().unwrap().field();
            let field = field.fields.first().ok_or_else(|| polars_err!(ComputeError: "Invalid IPC field structure: expected nested field but fields vector is empty"))?;
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
    }
}

/// Encode a dictionary array with a certain id.
///
/// # Panics
///
/// This will panic if the given array is not a [`DictionaryArray`].
pub fn encode_dictionary(
    dict_id: i64,
    array: &dyn Array,
    options: &WriteOptions,
) -> PolarsResult<EncodedData> {
    let PhysicalType::Dictionary(key_type) = array.dtype().to_physical_type() else {
        panic!("Given array is not a DictionaryArray")
    };

    match_integer_type!(key_type, |$T| {
        let array: &DictionaryArray<$T> = array.as_any().downcast_ref().unwrap();

        encode_dictionary_values(dict_id, array.values().as_ref(), options)
    })
}

pub fn encode_new_dictionaries(
    field: &IpcField,
    array: &dyn Array,
    options: &WriteOptions,
    dictionary_tracker: &mut DictionaryTracker,
    encoded_dictionaries: &mut Vec<EncodedData>,
) -> PolarsResult<()> {
    let mut dicts_to_encode = Vec::new();
    dictionaries_to_encode(field, array, dictionary_tracker, &mut dicts_to_encode)?;
    for (dict_id, dict_array) in dicts_to_encode {
        encoded_dictionaries.push(encode_dictionary(dict_id, dict_array.as_ref(), options)?);
    }
    Ok(())
}

pub fn encode_chunk(
    chunk: &RecordBatchT<Box<dyn Array>>,
    fields: &[IpcField],
    dictionary_tracker: &mut DictionaryTracker,
    options: &WriteOptions,
) -> PolarsResult<(Vec<EncodedData>, EncodedData)> {
    let mut encoded_message = EncodedData::default();
    let encoded_dictionaries = encode_chunk_amortized(
        chunk,
        fields,
        dictionary_tracker,
        options,
        &mut encoded_message,
    )?;
    Ok((encoded_dictionaries, encoded_message))
}

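// Usage sketch (assumed caller pattern, not code from this file): a writer feeds
// each batch through `encode_chunk` and emits any newly-seen dictionary batches
// before the record batch itself, so readers can resolve dictionary keys.
//
//     let mut tracker = DictionaryTracker {
//         dictionaries: Default::default(),
//         cannot_replace: true, // the file format forbids replacing a dictionary
//     };
//     let (dicts, message) = encode_chunk(&chunk, &fields, &mut tracker, &options)?;
//     for dict in &dicts {
//         // write dict.ipc_message, then dict.arrow_data
//     }
//     // write message.ipc_message, then message.arrow_data
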
// Amortizes `EncodedData` allocation.
pub fn encode_chunk_amortized(
    chunk: &RecordBatchT<Box<dyn Array>>,
    fields: &[IpcField],
    dictionary_tracker: &mut DictionaryTracker,
    options: &WriteOptions,
    encoded_message: &mut EncodedData,
) -> PolarsResult<Vec<EncodedData>> {
    let mut encoded_dictionaries = vec![];

    for (field, array) in fields.iter().zip(chunk.as_ref()) {
        encode_new_dictionaries(
            field,
            array.as_ref(),
            options,
            dictionary_tracker,
            &mut encoded_dictionaries,
        )?;
    }
    encode_record_batch(chunk, options, encoded_message);

    Ok(encoded_dictionaries)
}

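// Sketch of the amortization this enables (assumed caller pattern): reusing a
// single `EncodedData` across batches avoids reallocating its byte vectors,
// since `encode_record_batch` only clears `arrow_data` between calls.
//
//     let mut message = EncodedData::default();
//     for chunk in &chunks {
//         let dicts =
//             encode_chunk_amortized(chunk, &fields, &mut tracker, &options, &mut message)?;
//         // ... write `dicts` and `message`, then reuse `message` for the next chunk
//     }
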
fn serialize_compression(
    compression: Option<Compression>,
) -> Option<Box<arrow_format::ipc::BodyCompression>> {
    if let Some(compression) = compression {
        let codec = match compression {
            Compression::LZ4 => arrow_format::ipc::CompressionType::Lz4Frame,
            Compression::ZSTD(_) => arrow_format::ipc::CompressionType::Zstd,
        };
        Some(Box::new(arrow_format::ipc::BodyCompression {
            codec,
            method: arrow_format::ipc::BodyCompressionMethod::Buffer,
        }))
    } else {
        None
    }
}

fn set_variadic_buffer_counts(counts: &mut Vec<i64>, array: &dyn Array) {
    match array.dtype().to_storage() {
        ArrowDataType::Utf8View => {
            let array = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
            counts.push(array.data_buffers().len() as i64);
        },
        ArrowDataType::BinaryView => {
            let array = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
            counts.push(array.data_buffers().len() as i64);
        },
        ArrowDataType::Struct(_) => {
            let array = array.as_any().downcast_ref::<StructArray>().unwrap();
            for array in array.values() {
                set_variadic_buffer_counts(counts, array.as_ref())
            }
        },
        ArrowDataType::LargeList(_) => {
            // Subslicing can change the variadic buffer count, so we have to
            // slice here as well to stay synchronized.
            let array = array.as_any().downcast_ref::<LargeListArray>().unwrap();
            let offsets = array.offsets().buffer();
            let first = *offsets.first().unwrap();
            let last = *offsets.last().unwrap();
            let subslice = array
                .values()
                .sliced(first.to_usize(), last.to_usize() - first.to_usize());
            set_variadic_buffer_counts(counts, &*subslice)
        },
        ArrowDataType::FixedSizeList(_, _) => {
            let array = array.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
            set_variadic_buffer_counts(counts, array.values().as_ref())
        },
        // Don't traverse dictionary values as those are set when the `Dictionary` IPC struct
        // is read.
        ArrowDataType::Dictionary(_, _, _) => (),
        _ => (),
    }
}

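// Worked example (illustrative): a `Utf8View` column whose views reference two
// data buffers pushes a count of 2; a struct of two such columns pushes 2 twice.
// The counts end up in `RecordBatch::variadic_buffer_counts` so a reader knows
// how many variadic data buffers each view array in the message owns.
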
fn gc_bin_view<'a, T: ViewType + ?Sized>(
    arr: &'a Box<dyn Array>,
    concrete_arr: &'a BinaryViewArrayGeneric<T>,
) -> Cow<'a, Box<dyn Array>> {
    let bytes_len = concrete_arr.total_bytes_len();
    let buffer_len = concrete_arr.total_buffer_len();
    let extra_len = buffer_len.saturating_sub(bytes_len);
    if extra_len < bytes_len.min(1024) {
        // We can afford some tiny waste.
        Cow::Borrowed(arr)
    } else {
        // Force GC it.
        Cow::Owned(concrete_arr.clone().gc().boxed())
    }
}

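// Worked example (illustrative numbers): with `total_bytes_len = 100` and
// `total_buffer_len = 1_000_000`, `extra_len = 999_900` is at least
// `min(100, 1024)`, so the array is garbage-collected (compacted) before writing.
// With `total_buffer_len = 150`, `extra_len = 50` stays below the threshold and
// the array is borrowed as-is.
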
pub fn encode_array(
    array: &Box<dyn Array>,
    options: &WriteOptions,
    variadic_buffer_counts: &mut Vec<i64>,
    buffers: &mut Vec<ipc::Buffer>,
    arrow_data: &mut Vec<u8>,
    nodes: &mut Vec<ipc::FieldNode>,
    offset: &mut i64,
) {
    // We don't want to write all buffers in sliced arrays.
    let array = match array.dtype() {
        ArrowDataType::BinaryView => {
            let concrete_arr = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
            gc_bin_view(array, concrete_arr)
        },
        ArrowDataType::Utf8View => {
            let concrete_arr = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
            gc_bin_view(array, concrete_arr)
        },
        _ => Cow::Borrowed(array),
    };
    let array = array.as_ref().as_ref();

    set_variadic_buffer_counts(variadic_buffer_counts, array);

    write(
        array,
        buffers,
        arrow_data,
        nodes,
        offset,
        is_native_little_endian(),
        options.compression,
    )
}

/// Write [`RecordBatchT`] into two sets of bytes, one for the header (ipc::Schema::Message) and the
/// other for the batch's data
pub fn encode_record_batch(
    chunk: &RecordBatchT<Box<dyn Array>>,
    options: &WriteOptions,
    encoded_message: &mut EncodedData,
) {
    let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
    let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
    encoded_message.arrow_data.clear();

    let mut offset = 0;
    let mut variadic_buffer_counts = vec![];
    for array in chunk.arrays() {
        encode_array(
            array,
            options,
            &mut variadic_buffer_counts,
            &mut buffers,
            &mut encoded_message.arrow_data,
            &mut nodes,
            &mut offset,
        );
    }

    commit_encoded_arrays(
        chunk.len(),
        options,
        variadic_buffer_counts,
        buffers,
        nodes,
        None,
        encoded_message,
    );
}

pub fn commit_encoded_arrays(
    array_len: usize,
    options: &WriteOptions,
    variadic_buffer_counts: Vec<i64>,
    buffers: Vec<ipc::Buffer>,
    nodes: Vec<ipc::FieldNode>,
    custom_metadata: Option<Vec<KeyValue>>,
    encoded_message: &mut EncodedData,
) {
    let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
        None
    } else {
        Some(variadic_buffer_counts)
    };

    let compression = serialize_compression(options.compression);

    let message = arrow_format::ipc::Message {
        version: arrow_format::ipc::MetadataVersion::V5,
        header: Some(arrow_format::ipc::MessageHeader::RecordBatch(Box::new(
            arrow_format::ipc::RecordBatch {
                length: array_len as i64,
                nodes: Some(nodes),
                buffers: Some(buffers),
                compression,
                variadic_buffer_counts,
            },
        ))),
        body_length: encoded_message.arrow_data.len() as i64,
        custom_metadata,
    };

    let mut builder = Builder::new();
    let ipc_message = builder.finish(&message, None);
    encoded_message.ipc_message = ipc_message.to_vec();
}

pub fn encode_dictionary_values(
    dict_id: i64,
    values_array: &dyn Array,
    options: &WriteOptions,
) -> PolarsResult<EncodedData> {
    let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
    let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
    let mut arrow_data: Vec<u8> = vec![];
    let mut variadic_buffer_counts = vec![];
    set_variadic_buffer_counts(&mut variadic_buffer_counts, values_array);

    let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
        None
    } else {
        Some(variadic_buffer_counts)
    };

    write(
        values_array,
        &mut buffers,
        &mut arrow_data,
        &mut nodes,
        &mut 0,
        is_native_little_endian(),
        options.compression,
    );

    let compression = serialize_compression(options.compression);

    let message = arrow_format::ipc::Message {
        version: arrow_format::ipc::MetadataVersion::V5,
        header: Some(arrow_format::ipc::MessageHeader::DictionaryBatch(Box::new(
            arrow_format::ipc::DictionaryBatch {
                id: dict_id,
                data: Some(Box::new(arrow_format::ipc::RecordBatch {
                    length: values_array.len() as i64,
                    nodes: Some(nodes),
                    buffers: Some(buffers),
                    compression,
                    variadic_buffer_counts,
                })),
                is_delta: false,
            },
        ))),
        body_length: arrow_data.len() as i64,
        custom_metadata: None,
    };

    let mut builder = Builder::new();
    let ipc_message = builder.finish(&message, None);

    Ok(EncodedData {
        ipc_message: ipc_message.to_vec(),
        arrow_data,
    })
}

/// Keeps track of dictionaries that have been written, to avoid emitting the same dictionary
/// multiple times. Can optionally error if an update to an existing dictionary is attempted, which
/// isn't allowed in the `FileWriter`.
pub struct DictionaryTracker {
    pub dictionaries: Dictionaries,
    pub cannot_replace: bool,
}

impl DictionaryTracker {
    /// Keep track of the dictionary with the given ID and values. Behavior:
    ///
    /// * If this ID has been written already and has the same data, return `Ok(false)` to indicate
    ///   that the dictionary was not actually inserted (because it's already been seen).
    /// * If this ID has been written already but with different data, and this tracker is
    ///   configured to return an error, return an error.
    /// * If the tracker has not been configured to error on replacement or this dictionary
    ///   has never been seen before, return `Ok(true)` to indicate that the dictionary was just
    ///   inserted.
    pub fn insert(&mut self, dict_id: i64, array: &dyn Array) -> PolarsResult<bool> {
        let values = match array.dtype().to_storage() {
            ArrowDataType::Dictionary(key_type, _, _) => {
                match_integer_type!(key_type, |$T| {
                    let array = array
                        .as_any()
                        .downcast_ref::<DictionaryArray<$T>>()
                        .unwrap();
                    array.values()
                })
            },
            _ => unreachable!(),
        };

        // If a dictionary with this id was already emitted, check if it was the same.
        if let Some(last) = self.dictionaries.get(&dict_id) {
            if last.as_ref() == values.as_ref() {
                // Same dictionary values => no need to emit it again
                return Ok(false);
            } else if self.cannot_replace {
                polars_bail!(InvalidOperation:
                    "Dictionary replacement detected when writing IPC file format. \
                     Arrow IPC files only support a single dictionary for a given field \
                     across all batches."
                );
            }
        };

        self.dictionaries.insert(dict_id, values.clone());
        Ok(true)
    }
}

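// Behaviour sketch mirroring the doc comment above (`dict_array` is a
// hypothetical `DictionaryArray`):
//
//     let mut tracker = DictionaryTracker {
//         dictionaries: Default::default(),
//         cannot_replace: true,
//     };
//     assert!(tracker.insert(0, &dict_array)?);  // first sighting: encode it
//     assert!(!tracker.insert(0, &dict_array)?); // same values: skip re-encoding
//     // inserting different values under id 0 now errors, as `cannot_replace` is set
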
/// Stores the encoded data, which is an ipc::Schema::Message, and optional Arrow data
#[derive(Debug, Default)]
pub struct EncodedData {
    /// An encoded ipc::Schema::Message
    pub ipc_message: Vec<u8>,
    /// Arrow buffers to be written, should be an empty vec for schema messages
    pub arrow_data: Vec<u8>,
}

/// Stores the encoded data, which is an ipc::Schema::Message, and optional Arrow data
#[derive(Debug, Default)]
pub struct EncodedDataBytes {
    /// An encoded ipc::Schema::Message
    pub ipc_message: Bytes,
    /// Arrow buffers to be written, should be empty for schema messages
    pub arrow_data: Bytes,
}

/// Calculate the next 64-byte boundary and return the number of bytes needed to pad to it
#[inline]
pub(crate) fn pad_to_64(len: usize) -> usize {
    ((len + 63) & !63) - len
}

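// Worked examples: `pad_to_64(0) == 0`, `pad_to_64(5) == 59`, `pad_to_64(64) == 0`.
// `(len + 63) & !63` rounds `len` up to the next multiple of 64, the buffer
// alignment this writer pads to.
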
/// A [`RecordBatchT`] with optional accompanying IPC fields.
#[derive(Debug, Clone, PartialEq)]
pub struct Record<'a> {
    columns: Cow<'a, RecordBatchT<Box<dyn Array>>>,
    fields: Option<Cow<'a, [IpcField]>>,
}

impl Record<'_> {
    /// Get the IPC fields for this record.
    pub fn fields(&self) -> Option<&[IpcField]> {
        self.fields.as_deref()
    }

    /// Get the Arrow columns in this record.
    pub fn columns(&self) -> &RecordBatchT<Box<dyn Array>> {
        self.columns.borrow()
    }
}

impl From<RecordBatchT<Box<dyn Array>>> for Record<'static> {
    fn from(columns: RecordBatchT<Box<dyn Array>>) -> Self {
        Self {
            columns: Cow::Owned(columns),
            fields: None,
        }
    }
}

impl<'a, F> From<(RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>
where
    F: Into<Cow<'a, [IpcField]>>,
{
    fn from((columns, fields): (RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {
        Self {
            columns: Cow::Owned(columns),
            fields: fields.map(|f| f.into()),
        }
    }
}

impl<'a, F> From<(&'a RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>
where
    F: Into<Cow<'a, [IpcField]>>,
{
    fn from((columns, fields): (&'a RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {
        Self {
            columns: Cow::Borrowed(columns),
            fields: fields.map(|f| f.into()),
        }
    }
}

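// Construction sketch (assumed usage of the `From` impls above): a `Record` can
// borrow an existing batch together with optional IPC fields, avoiding a copy.
//
//     let record: Record = (&batch, Some(ipc_fields.as_slice())).into();
//     assert!(record.fields().is_some());
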
/// Create an IPC Block. Panics when the given sizes do not fit the target integer types.
pub fn arrow_ipc_block(
    offset: usize,
    meta_data_length: usize,
    body_length: usize,
) -> arrow_format::ipc::Block {
    arrow_format::ipc::Block {
        offset: i64::try_from(offset).unwrap(),
        meta_data_length: i32::try_from(meta_data_length).unwrap(),
        body_length: i64::try_from(body_length).unwrap(),
    }
}