Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-arrow/src/io/ipc/read/schema.rs
8446 views
1
use std::sync::Arc;
2
3
use arrow_format::ipc::planus::ReadAsRoot;
4
use arrow_format::ipc::{FieldRef, FixedSizeListRef, MapRef, TimeRef, TimestampRef, UnionRef};
5
use polars_error::{PolarsResult, polars_bail, polars_err};
6
use polars_utils::pl_str::PlSmallStr;
7
8
use super::super::{IpcField, IpcSchema};
9
use super::{OutOfSpecKind, StreamMetadata};
10
use crate::datatypes::{
11
ArrowDataType, ArrowSchema, Extension, ExtensionType, Field, IntegerType, IntervalUnit,
12
Metadata, TimeUnit, UnionMode, UnionType, get_extension,
13
};
14
15
fn try_unzip_vec<A, B, I: Iterator<Item = PolarsResult<(A, B)>>>(
16
iter: I,
17
) -> PolarsResult<(Vec<A>, Vec<B>)> {
18
let mut a = vec![];
19
let mut b = vec![];
20
for maybe_item in iter {
21
let (a_i, b_i) = maybe_item?;
22
a.push(a_i);
23
b.push(b_i);
24
}
25
26
Ok((a, b))
27
}
28
29
fn deserialize_field(ipc_field: arrow_format::ipc::FieldRef) -> PolarsResult<(Field, IpcField)> {
30
let metadata = read_metadata(&ipc_field)?;
31
32
let extension = metadata.as_ref().and_then(get_extension);
33
34
let (dtype, ipc_field_) = get_dtype(ipc_field, extension, true)?;
35
36
let field = Field {
37
name: PlSmallStr::from_str(
38
ipc_field
39
.name()?
40
.ok_or_else(|| polars_err!(oos = "Every field in IPC must have a name"))?,
41
),
42
dtype,
43
is_nullable: ipc_field.nullable()?,
44
metadata: metadata.map(Arc::new),
45
};
46
47
Ok((field, ipc_field_))
48
}
49
50
fn read_metadata(field: &arrow_format::ipc::FieldRef) -> PolarsResult<Option<Metadata>> {
51
Ok(if let Some(list) = field.custom_metadata()? {
52
let mut metadata_map = Metadata::new();
53
for kv in list {
54
let kv = kv?;
55
if let (Some(k), Some(v)) = (kv.key()?, kv.value()?) {
56
metadata_map.insert(PlSmallStr::from_str(k), PlSmallStr::from_str(v));
57
}
58
}
59
Some(metadata_map)
60
} else {
61
None
62
})
63
}
64
65
fn deserialize_integer(int: arrow_format::ipc::IntRef) -> PolarsResult<IntegerType> {
66
Ok(match (int.bit_width()?, int.is_signed()?) {
67
(8, true) => IntegerType::Int8,
68
(8, false) => IntegerType::UInt8,
69
(16, true) => IntegerType::Int16,
70
(16, false) => IntegerType::UInt16,
71
(32, true) => IntegerType::Int32,
72
(32, false) => IntegerType::UInt32,
73
(64, true) => IntegerType::Int64,
74
(64, false) => IntegerType::UInt64,
75
(128, true) => IntegerType::Int128,
76
(128, false) => IntegerType::UInt128,
77
_ => polars_bail!(oos = "IPC: indexType can only be 8, 16, 32, 64 or 128."),
78
})
79
}
80
81
fn deserialize_timeunit(time_unit: arrow_format::ipc::TimeUnit) -> PolarsResult<TimeUnit> {
82
use arrow_format::ipc::TimeUnit::*;
83
Ok(match time_unit {
84
Second => TimeUnit::Second,
85
Millisecond => TimeUnit::Millisecond,
86
Microsecond => TimeUnit::Microsecond,
87
Nanosecond => TimeUnit::Nanosecond,
88
})
89
}
90
91
fn deserialize_time(time: TimeRef) -> PolarsResult<(ArrowDataType, IpcField)> {
92
let unit = deserialize_timeunit(time.unit()?)?;
93
94
let dtype = match (time.bit_width()?, unit) {
95
(32, TimeUnit::Second) => ArrowDataType::Time32(TimeUnit::Second),
96
(32, TimeUnit::Millisecond) => ArrowDataType::Time32(TimeUnit::Millisecond),
97
(64, TimeUnit::Microsecond) => ArrowDataType::Time64(TimeUnit::Microsecond),
98
(64, TimeUnit::Nanosecond) => ArrowDataType::Time64(TimeUnit::Nanosecond),
99
(bits, precision) => {
100
polars_bail!(ComputeError:
101
"Time type with bit width of {bits} and unit of {precision:?}"
102
)
103
},
104
};
105
Ok((dtype, IpcField::default()))
106
}
107
108
fn deserialize_timestamp(timestamp: TimestampRef) -> PolarsResult<(ArrowDataType, IpcField)> {
109
let timezone = timestamp.timezone()?;
110
let time_unit = deserialize_timeunit(timestamp.unit()?)?;
111
Ok((
112
ArrowDataType::Timestamp(time_unit, timezone.map(PlSmallStr::from_str)),
113
IpcField::default(),
114
))
115
}
116
117
fn deserialize_union(union_: UnionRef, field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
118
let mode = UnionMode::sparse(union_.mode()? == arrow_format::ipc::UnionMode::Sparse);
119
let ids = union_.type_ids()?.map(|x| x.iter().collect());
120
121
let fields = field
122
.children()?
123
.ok_or_else(|| polars_err!(oos = "IPC: Union must contain children"))?;
124
if fields.is_empty() {
125
polars_bail!(oos = "IPC: Union must contain at least one child");
126
}
127
128
let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| {
129
let (field, fields) = deserialize_field(field?)?;
130
Ok((field, fields))
131
}))?;
132
let ipc_field = IpcField {
133
fields: ipc_fields,
134
dictionary_id: None,
135
};
136
Ok((
137
ArrowDataType::Union(Box::new(UnionType { fields, ids, mode })),
138
ipc_field,
139
))
140
}
141
142
fn deserialize_map(map: MapRef, field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
143
let is_sorted = map.keys_sorted()?;
144
145
let children = field
146
.children()?
147
.ok_or_else(|| polars_err!(oos = "IPC: Map must contain children"))?;
148
let inner = children
149
.get(0)
150
.ok_or_else(|| polars_err!(oos = "IPC: Map must contain one child"))??;
151
let (field, ipc_field) = deserialize_field(inner)?;
152
153
let dtype = ArrowDataType::Map(Box::new(field), is_sorted);
154
Ok((
155
dtype,
156
IpcField {
157
fields: vec![ipc_field],
158
dictionary_id: None,
159
},
160
))
161
}
162
163
fn deserialize_struct(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
164
let fields = field
165
.children()?
166
.ok_or_else(|| polars_err!(oos = "IPC: Struct must contain children"))?;
167
let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| {
168
let (field, fields) = deserialize_field(field?)?;
169
Ok((field, fields))
170
}))?;
171
let ipc_field = IpcField {
172
fields: ipc_fields,
173
dictionary_id: None,
174
};
175
Ok((ArrowDataType::Struct(fields), ipc_field))
176
}
177
178
fn deserialize_list(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
179
let children = field
180
.children()?
181
.ok_or_else(|| polars_err!(oos = "IPC: List must contain children"))?;
182
let inner = children
183
.get(0)
184
.ok_or_else(|| polars_err!(oos = "IPC: List must contain one child"))??;
185
let (field, ipc_field) = deserialize_field(inner)?;
186
187
Ok((
188
ArrowDataType::List(Box::new(field)),
189
IpcField {
190
fields: vec![ipc_field],
191
dictionary_id: None,
192
},
193
))
194
}
195
196
fn deserialize_large_list(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
197
let children = field
198
.children()?
199
.ok_or_else(|| polars_err!(oos = "IPC: List must contain children"))?;
200
let inner = children
201
.get(0)
202
.ok_or_else(|| polars_err!(oos = "IPC: List must contain one child"))??;
203
let (field, ipc_field) = deserialize_field(inner)?;
204
205
Ok((
206
ArrowDataType::LargeList(Box::new(field)),
207
IpcField {
208
fields: vec![ipc_field],
209
dictionary_id: None,
210
},
211
))
212
}
213
214
fn deserialize_fixed_size_list(
215
list: FixedSizeListRef,
216
field: FieldRef,
217
) -> PolarsResult<(ArrowDataType, IpcField)> {
218
let children = field
219
.children()?
220
.ok_or_else(|| polars_err!(oos = "IPC: FixedSizeList must contain children"))?;
221
let inner = children
222
.get(0)
223
.ok_or_else(|| polars_err!(oos = "IPC: FixedSizeList must contain one child"))??;
224
let (field, ipc_field) = deserialize_field(inner)?;
225
226
let size = list
227
.list_size()?
228
.try_into()
229
.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
230
231
Ok((
232
ArrowDataType::FixedSizeList(Box::new(field), size),
233
IpcField {
234
fields: vec![ipc_field],
235
dictionary_id: None,
236
},
237
))
238
}
239
240
/// Get the Arrow data type from the flatbuffer Field table
241
fn get_dtype(
242
field: arrow_format::ipc::FieldRef,
243
extension: Extension,
244
may_be_dictionary: bool,
245
) -> PolarsResult<(ArrowDataType, IpcField)> {
246
if let Some(dictionary) = field.dictionary()? {
247
if may_be_dictionary {
248
let int = dictionary
249
.index_type()?
250
.ok_or_else(|| polars_err!(oos = "indexType is mandatory in Dictionary."))?;
251
let index_type = deserialize_integer(int)?;
252
let (inner, mut ipc_field) = get_dtype(field, extension, false)?;
253
ipc_field.dictionary_id = Some(dictionary.id()?);
254
return Ok((
255
ArrowDataType::Dictionary(index_type, Box::new(inner), dictionary.is_ordered()?),
256
ipc_field,
257
));
258
}
259
}
260
261
if let Some(extension) = extension {
262
let (name, metadata) = extension;
263
let (dtype, fields) = get_dtype(field, None, false)?;
264
return Ok((
265
ArrowDataType::Extension(Box::new(ExtensionType {
266
name,
267
inner: dtype,
268
metadata,
269
})),
270
fields,
271
));
272
}
273
274
let type_ = field
275
.type_()?
276
.ok_or_else(|| polars_err!(oos = "IPC: field type is mandatory"))?;
277
278
use arrow_format::ipc::TypeRef::*;
279
Ok(match type_ {
280
Null(_) => (ArrowDataType::Null, IpcField::default()),
281
Bool(_) => (ArrowDataType::Boolean, IpcField::default()),
282
Int(int) => {
283
let dtype = deserialize_integer(int)?.into();
284
(dtype, IpcField::default())
285
},
286
Binary(_) => (ArrowDataType::Binary, IpcField::default()),
287
LargeBinary(_) => (ArrowDataType::LargeBinary, IpcField::default()),
288
Utf8(_) => (ArrowDataType::Utf8, IpcField::default()),
289
LargeUtf8(_) => (ArrowDataType::LargeUtf8, IpcField::default()),
290
BinaryView(_) => (ArrowDataType::BinaryView, IpcField::default()),
291
Utf8View(_) => (ArrowDataType::Utf8View, IpcField::default()),
292
FixedSizeBinary(fixed) => (
293
ArrowDataType::FixedSizeBinary(
294
fixed
295
.byte_width()?
296
.try_into()
297
.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?,
298
),
299
IpcField::default(),
300
),
301
FloatingPoint(float) => {
302
let dtype = match float.precision()? {
303
arrow_format::ipc::Precision::Half => ArrowDataType::Float16,
304
arrow_format::ipc::Precision::Single => ArrowDataType::Float32,
305
arrow_format::ipc::Precision::Double => ArrowDataType::Float64,
306
};
307
(dtype, IpcField::default())
308
},
309
Date(date) => {
310
let dtype = match date.unit()? {
311
arrow_format::ipc::DateUnit::Day => ArrowDataType::Date32,
312
arrow_format::ipc::DateUnit::Millisecond => ArrowDataType::Date64,
313
};
314
(dtype, IpcField::default())
315
},
316
Time(time) => deserialize_time(time)?,
317
Timestamp(timestamp) => deserialize_timestamp(timestamp)?,
318
Interval(interval) => {
319
let dtype = match interval.unit()? {
320
arrow_format::ipc::IntervalUnit::YearMonth => {
321
ArrowDataType::Interval(IntervalUnit::YearMonth)
322
},
323
arrow_format::ipc::IntervalUnit::DayTime => {
324
ArrowDataType::Interval(IntervalUnit::DayTime)
325
},
326
arrow_format::ipc::IntervalUnit::MonthDayNano => {
327
ArrowDataType::Interval(IntervalUnit::MonthDayNano)
328
},
329
};
330
(dtype, IpcField::default())
331
},
332
Duration(duration) => {
333
let time_unit = deserialize_timeunit(duration.unit()?)?;
334
(ArrowDataType::Duration(time_unit), IpcField::default())
335
},
336
Decimal(decimal) => {
337
let bit_width: usize = decimal
338
.bit_width()?
339
.try_into()
340
.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
341
let precision: usize = decimal
342
.precision()?
343
.try_into()
344
.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
345
let scale: usize = decimal
346
.scale()?
347
.try_into()
348
.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
349
350
let dtype = match bit_width {
351
32 => ArrowDataType::Decimal32(precision, scale),
352
64 => ArrowDataType::Decimal64(precision, scale),
353
128 => ArrowDataType::Decimal(precision, scale),
354
256 => ArrowDataType::Decimal256(precision, scale),
355
_ => return Err(polars_err!(oos = OutOfSpecKind::NegativeFooterLength)),
356
};
357
358
(dtype, IpcField::default())
359
},
360
List(_) => deserialize_list(field)?,
361
LargeList(_) => deserialize_large_list(field)?,
362
FixedSizeList(list) => deserialize_fixed_size_list(list, field)?,
363
Struct(_) => deserialize_struct(field)?,
364
Union(union_) => deserialize_union(union_, field)?,
365
Map(map) => deserialize_map(map, field)?,
366
RunEndEncoded(_) => todo!(),
367
LargeListView(_) | ListView(_) => todo!(),
368
})
369
}
370
371
/// Deserialize an flatbuffers-encoded Schema message into [`ArrowSchema`] and [`IpcSchema`].
372
pub fn deserialize_schema(
373
message: &[u8],
374
) -> PolarsResult<(ArrowSchema, IpcSchema, Option<Metadata>)> {
375
let message = arrow_format::ipc::MessageRef::read_as_root(message)
376
.map_err(|err| polars_err!(oos = format!("Unable deserialize message: {err:?}")))?;
377
378
let schema = match message
379
.header()?
380
.ok_or_else(|| polars_err!(oos = "Unable to convert header to a schema".to_string()))?
381
{
382
arrow_format::ipc::MessageHeaderRef::Schema(schema) => PolarsResult::Ok(schema),
383
_ => polars_bail!(ComputeError: "The message is expected to be a Schema message"),
384
}?;
385
386
fb_to_schema(schema)
387
}
388
389
/// Deserialize the raw Schema table from IPC format to Schema data type
390
pub(super) fn fb_to_schema(
391
schema: arrow_format::ipc::SchemaRef,
392
) -> PolarsResult<(ArrowSchema, IpcSchema, Option<Metadata>)> {
393
let fields = schema
394
.fields()?
395
.ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingFields))?;
396
397
let mut arrow_schema = ArrowSchema::with_capacity(fields.len());
398
let mut ipc_fields = Vec::with_capacity(fields.len());
399
400
for field in fields {
401
let (field, ipc_field) = deserialize_field(field?)?;
402
arrow_schema.insert(field.name.clone(), field);
403
ipc_fields.push(ipc_field);
404
}
405
406
let is_little_endian = match schema.endianness()? {
407
arrow_format::ipc::Endianness::Little => true,
408
arrow_format::ipc::Endianness::Big => false,
409
};
410
411
let custom_schema_metadata = match schema.custom_metadata()? {
412
None => None,
413
Some(metadata) => {
414
let metadata: Metadata = metadata
415
.into_iter()
416
.filter_map(|kv_result| {
417
// TODO: silently hiding errors here
418
let kv_ref = kv_result.ok()?;
419
Some((kv_ref.key().ok()??.into(), kv_ref.value().ok()??.into()))
420
})
421
.collect();
422
423
if metadata.is_empty() {
424
None
425
} else {
426
Some(metadata)
427
}
428
},
429
};
430
431
Ok((
432
arrow_schema,
433
IpcSchema {
434
fields: ipc_fields,
435
is_little_endian,
436
},
437
custom_schema_metadata,
438
))
439
}
440
441
pub(super) fn deserialize_stream_metadata(meta: &[u8]) -> PolarsResult<StreamMetadata> {
442
let message = arrow_format::ipc::MessageRef::read_as_root(meta)
443
.map_err(|err| polars_err!(oos = format!("Unable to get root as message: {err:?}")))?;
444
let version = message.version()?;
445
// message header is a Schema, so read it
446
let header = message
447
.header()?
448
.ok_or_else(|| polars_err!(oos = "Unable to read the first IPC message"))?;
449
let schema = if let arrow_format::ipc::MessageHeaderRef::Schema(schema) = header {
450
schema
451
} else {
452
polars_bail!(oos = "The first IPC message of the stream must be a schema")
453
};
454
let (schema, ipc_schema, custom_schema_metadata) = fb_to_schema(schema)?;
455
456
Ok(StreamMetadata {
457
schema,
458
version,
459
ipc_schema,
460
custom_schema_metadata,
461
})
462
}
463
464