Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-arrow/src/io/ipc/read/schema.rs
6940 views
1
use std::sync::Arc;
2
3
use arrow_format::ipc::planus::ReadAsRoot;
4
use arrow_format::ipc::{FieldRef, FixedSizeListRef, MapRef, TimeRef, TimestampRef, UnionRef};
5
use polars_error::{PolarsResult, polars_bail, polars_err};
6
use polars_utils::pl_str::PlSmallStr;
7
8
use super::super::{IpcField, IpcSchema};
9
use super::{OutOfSpecKind, StreamMetadata};
10
use crate::datatypes::{
11
ArrowDataType, ArrowSchema, Extension, ExtensionType, Field, IntegerType, IntervalUnit,
12
Metadata, TimeUnit, UnionMode, UnionType, get_extension,
13
};
14
15
fn try_unzip_vec<A, B, I: Iterator<Item = PolarsResult<(A, B)>>>(
16
iter: I,
17
) -> PolarsResult<(Vec<A>, Vec<B>)> {
18
let mut a = vec![];
19
let mut b = vec![];
20
for maybe_item in iter {
21
let (a_i, b_i) = maybe_item?;
22
a.push(a_i);
23
b.push(b_i);
24
}
25
26
Ok((a, b))
27
}
28
29
fn deserialize_field(ipc_field: arrow_format::ipc::FieldRef) -> PolarsResult<(Field, IpcField)> {
30
let metadata = read_metadata(&ipc_field)?;
31
32
let extension = metadata.as_ref().and_then(get_extension);
33
34
let (dtype, ipc_field_) = get_dtype(ipc_field, extension, true)?;
35
36
let field = Field {
37
name: PlSmallStr::from_str(
38
ipc_field
39
.name()?
40
.ok_or_else(|| polars_err!(oos = "Every field in IPC must have a name"))?,
41
),
42
dtype,
43
is_nullable: ipc_field.nullable()?,
44
metadata: metadata.map(Arc::new),
45
};
46
47
Ok((field, ipc_field_))
48
}
49
50
fn read_metadata(field: &arrow_format::ipc::FieldRef) -> PolarsResult<Option<Metadata>> {
51
Ok(if let Some(list) = field.custom_metadata()? {
52
let mut metadata_map = Metadata::new();
53
for kv in list {
54
let kv = kv?;
55
if let (Some(k), Some(v)) = (kv.key()?, kv.value()?) {
56
metadata_map.insert(PlSmallStr::from_str(k), PlSmallStr::from_str(v));
57
}
58
}
59
Some(metadata_map)
60
} else {
61
None
62
})
63
}
64
65
fn deserialize_integer(int: arrow_format::ipc::IntRef) -> PolarsResult<IntegerType> {
66
Ok(match (int.bit_width()?, int.is_signed()?) {
67
(8, true) => IntegerType::Int8,
68
(8, false) => IntegerType::UInt8,
69
(16, true) => IntegerType::Int16,
70
(16, false) => IntegerType::UInt16,
71
(32, true) => IntegerType::Int32,
72
(32, false) => IntegerType::UInt32,
73
(64, true) => IntegerType::Int64,
74
(64, false) => IntegerType::UInt64,
75
(128, true) => IntegerType::Int128,
76
_ => polars_bail!(oos = "IPC: indexType can only be 8, 16, 32, 64 or 128."),
77
})
78
}
79
80
fn deserialize_timeunit(time_unit: arrow_format::ipc::TimeUnit) -> PolarsResult<TimeUnit> {
81
use arrow_format::ipc::TimeUnit::*;
82
Ok(match time_unit {
83
Second => TimeUnit::Second,
84
Millisecond => TimeUnit::Millisecond,
85
Microsecond => TimeUnit::Microsecond,
86
Nanosecond => TimeUnit::Nanosecond,
87
})
88
}
89
90
fn deserialize_time(time: TimeRef) -> PolarsResult<(ArrowDataType, IpcField)> {
91
let unit = deserialize_timeunit(time.unit()?)?;
92
93
let dtype = match (time.bit_width()?, unit) {
94
(32, TimeUnit::Second) => ArrowDataType::Time32(TimeUnit::Second),
95
(32, TimeUnit::Millisecond) => ArrowDataType::Time32(TimeUnit::Millisecond),
96
(64, TimeUnit::Microsecond) => ArrowDataType::Time64(TimeUnit::Microsecond),
97
(64, TimeUnit::Nanosecond) => ArrowDataType::Time64(TimeUnit::Nanosecond),
98
(bits, precision) => {
99
polars_bail!(ComputeError:
100
"Time type with bit width of {bits} and unit of {precision:?}"
101
)
102
},
103
};
104
Ok((dtype, IpcField::default()))
105
}
106
107
fn deserialize_timestamp(timestamp: TimestampRef) -> PolarsResult<(ArrowDataType, IpcField)> {
108
let timezone = timestamp.timezone()?;
109
let time_unit = deserialize_timeunit(timestamp.unit()?)?;
110
Ok((
111
ArrowDataType::Timestamp(time_unit, timezone.map(PlSmallStr::from_str)),
112
IpcField::default(),
113
))
114
}
115
116
fn deserialize_union(union_: UnionRef, field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
117
let mode = UnionMode::sparse(union_.mode()? == arrow_format::ipc::UnionMode::Sparse);
118
let ids = union_.type_ids()?.map(|x| x.iter().collect());
119
120
let fields = field
121
.children()?
122
.ok_or_else(|| polars_err!(oos = "IPC: Union must contain children"))?;
123
if fields.is_empty() {
124
polars_bail!(oos = "IPC: Union must contain at least one child");
125
}
126
127
let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| {
128
let (field, fields) = deserialize_field(field?)?;
129
Ok((field, fields))
130
}))?;
131
let ipc_field = IpcField {
132
fields: ipc_fields,
133
dictionary_id: None,
134
};
135
Ok((
136
ArrowDataType::Union(Box::new(UnionType { fields, ids, mode })),
137
ipc_field,
138
))
139
}
140
141
fn deserialize_map(map: MapRef, field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
142
let is_sorted = map.keys_sorted()?;
143
144
let children = field
145
.children()?
146
.ok_or_else(|| polars_err!(oos = "IPC: Map must contain children"))?;
147
let inner = children
148
.get(0)
149
.ok_or_else(|| polars_err!(oos = "IPC: Map must contain one child"))??;
150
let (field, ipc_field) = deserialize_field(inner)?;
151
152
let dtype = ArrowDataType::Map(Box::new(field), is_sorted);
153
Ok((
154
dtype,
155
IpcField {
156
fields: vec![ipc_field],
157
dictionary_id: None,
158
},
159
))
160
}
161
162
fn deserialize_struct(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
163
let fields = field
164
.children()?
165
.ok_or_else(|| polars_err!(oos = "IPC: Struct must contain children"))?;
166
let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| {
167
let (field, fields) = deserialize_field(field?)?;
168
Ok((field, fields))
169
}))?;
170
let ipc_field = IpcField {
171
fields: ipc_fields,
172
dictionary_id: None,
173
};
174
Ok((ArrowDataType::Struct(fields), ipc_field))
175
}
176
177
fn deserialize_list(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
178
let children = field
179
.children()?
180
.ok_or_else(|| polars_err!(oos = "IPC: List must contain children"))?;
181
let inner = children
182
.get(0)
183
.ok_or_else(|| polars_err!(oos = "IPC: List must contain one child"))??;
184
let (field, ipc_field) = deserialize_field(inner)?;
185
186
Ok((
187
ArrowDataType::List(Box::new(field)),
188
IpcField {
189
fields: vec![ipc_field],
190
dictionary_id: None,
191
},
192
))
193
}
194
195
fn deserialize_large_list(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
196
let children = field
197
.children()?
198
.ok_or_else(|| polars_err!(oos = "IPC: List must contain children"))?;
199
let inner = children
200
.get(0)
201
.ok_or_else(|| polars_err!(oos = "IPC: List must contain one child"))??;
202
let (field, ipc_field) = deserialize_field(inner)?;
203
204
Ok((
205
ArrowDataType::LargeList(Box::new(field)),
206
IpcField {
207
fields: vec![ipc_field],
208
dictionary_id: None,
209
},
210
))
211
}
212
213
fn deserialize_fixed_size_list(
214
list: FixedSizeListRef,
215
field: FieldRef,
216
) -> PolarsResult<(ArrowDataType, IpcField)> {
217
let children = field
218
.children()?
219
.ok_or_else(|| polars_err!(oos = "IPC: FixedSizeList must contain children"))?;
220
let inner = children
221
.get(0)
222
.ok_or_else(|| polars_err!(oos = "IPC: FixedSizeList must contain one child"))??;
223
let (field, ipc_field) = deserialize_field(inner)?;
224
225
let size = list
226
.list_size()?
227
.try_into()
228
.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
229
230
Ok((
231
ArrowDataType::FixedSizeList(Box::new(field), size),
232
IpcField {
233
fields: vec![ipc_field],
234
dictionary_id: None,
235
},
236
))
237
}
238
239
/// Get the Arrow data type from the flatbuffer Field table
240
fn get_dtype(
241
field: arrow_format::ipc::FieldRef,
242
extension: Extension,
243
may_be_dictionary: bool,
244
) -> PolarsResult<(ArrowDataType, IpcField)> {
245
if let Some(dictionary) = field.dictionary()? {
246
if may_be_dictionary {
247
let int = dictionary
248
.index_type()?
249
.ok_or_else(|| polars_err!(oos = "indexType is mandatory in Dictionary."))?;
250
let index_type = deserialize_integer(int)?;
251
let (inner, mut ipc_field) = get_dtype(field, extension, false)?;
252
ipc_field.dictionary_id = Some(dictionary.id()?);
253
return Ok((
254
ArrowDataType::Dictionary(index_type, Box::new(inner), dictionary.is_ordered()?),
255
ipc_field,
256
));
257
}
258
}
259
260
if let Some(extension) = extension {
261
let (name, metadata) = extension;
262
let (dtype, fields) = get_dtype(field, None, false)?;
263
return Ok((
264
ArrowDataType::Extension(Box::new(ExtensionType {
265
name,
266
inner: dtype,
267
metadata,
268
})),
269
fields,
270
));
271
}
272
273
let type_ = field
274
.type_()?
275
.ok_or_else(|| polars_err!(oos = "IPC: field type is mandatory"))?;
276
277
use arrow_format::ipc::TypeRef::*;
278
Ok(match type_ {
279
Null(_) => (ArrowDataType::Null, IpcField::default()),
280
Bool(_) => (ArrowDataType::Boolean, IpcField::default()),
281
Int(int) => {
282
let dtype = deserialize_integer(int)?.into();
283
(dtype, IpcField::default())
284
},
285
Binary(_) => (ArrowDataType::Binary, IpcField::default()),
286
LargeBinary(_) => (ArrowDataType::LargeBinary, IpcField::default()),
287
Utf8(_) => (ArrowDataType::Utf8, IpcField::default()),
288
LargeUtf8(_) => (ArrowDataType::LargeUtf8, IpcField::default()),
289
BinaryView(_) => (ArrowDataType::BinaryView, IpcField::default()),
290
Utf8View(_) => (ArrowDataType::Utf8View, IpcField::default()),
291
FixedSizeBinary(fixed) => (
292
ArrowDataType::FixedSizeBinary(
293
fixed
294
.byte_width()?
295
.try_into()
296
.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?,
297
),
298
IpcField::default(),
299
),
300
FloatingPoint(float) => {
301
let dtype = match float.precision()? {
302
arrow_format::ipc::Precision::Half => ArrowDataType::Float16,
303
arrow_format::ipc::Precision::Single => ArrowDataType::Float32,
304
arrow_format::ipc::Precision::Double => ArrowDataType::Float64,
305
};
306
(dtype, IpcField::default())
307
},
308
Date(date) => {
309
let dtype = match date.unit()? {
310
arrow_format::ipc::DateUnit::Day => ArrowDataType::Date32,
311
arrow_format::ipc::DateUnit::Millisecond => ArrowDataType::Date64,
312
};
313
(dtype, IpcField::default())
314
},
315
Time(time) => deserialize_time(time)?,
316
Timestamp(timestamp) => deserialize_timestamp(timestamp)?,
317
Interval(interval) => {
318
let dtype = match interval.unit()? {
319
arrow_format::ipc::IntervalUnit::YearMonth => {
320
ArrowDataType::Interval(IntervalUnit::YearMonth)
321
},
322
arrow_format::ipc::IntervalUnit::DayTime => {
323
ArrowDataType::Interval(IntervalUnit::DayTime)
324
},
325
arrow_format::ipc::IntervalUnit::MonthDayNano => {
326
ArrowDataType::Interval(IntervalUnit::MonthDayNano)
327
},
328
};
329
(dtype, IpcField::default())
330
},
331
Duration(duration) => {
332
let time_unit = deserialize_timeunit(duration.unit()?)?;
333
(ArrowDataType::Duration(time_unit), IpcField::default())
334
},
335
Decimal(decimal) => {
336
let bit_width: usize = decimal
337
.bit_width()?
338
.try_into()
339
.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
340
let precision: usize = decimal
341
.precision()?
342
.try_into()
343
.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
344
let scale: usize = decimal
345
.scale()?
346
.try_into()
347
.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
348
349
let dtype = match bit_width {
350
32 => ArrowDataType::Decimal32(precision, scale),
351
64 => ArrowDataType::Decimal64(precision, scale),
352
128 => ArrowDataType::Decimal(precision, scale),
353
256 => ArrowDataType::Decimal256(precision, scale),
354
_ => return Err(polars_err!(oos = OutOfSpecKind::NegativeFooterLength)),
355
};
356
357
(dtype, IpcField::default())
358
},
359
List(_) => deserialize_list(field)?,
360
LargeList(_) => deserialize_large_list(field)?,
361
FixedSizeList(list) => deserialize_fixed_size_list(list, field)?,
362
Struct(_) => deserialize_struct(field)?,
363
Union(union_) => deserialize_union(union_, field)?,
364
Map(map) => deserialize_map(map, field)?,
365
RunEndEncoded(_) => todo!(),
366
LargeListView(_) | ListView(_) => todo!(),
367
})
368
}
369
370
/// Deserialize an flatbuffers-encoded Schema message into [`ArrowSchema`] and [`IpcSchema`].
371
pub fn deserialize_schema(
372
message: &[u8],
373
) -> PolarsResult<(ArrowSchema, IpcSchema, Option<Metadata>)> {
374
let message = arrow_format::ipc::MessageRef::read_as_root(message)
375
.map_err(|err| polars_err!(oos = format!("Unable deserialize message: {err:?}")))?;
376
377
let schema = match message
378
.header()?
379
.ok_or_else(|| polars_err!(oos = "Unable to convert header to a schema".to_string()))?
380
{
381
arrow_format::ipc::MessageHeaderRef::Schema(schema) => PolarsResult::Ok(schema),
382
_ => polars_bail!(ComputeError: "The message is expected to be a Schema message"),
383
}?;
384
385
fb_to_schema(schema)
386
}
387
388
/// Deserialize the raw Schema table from IPC format to Schema data type
389
pub(super) fn fb_to_schema(
390
schema: arrow_format::ipc::SchemaRef,
391
) -> PolarsResult<(ArrowSchema, IpcSchema, Option<Metadata>)> {
392
let fields = schema
393
.fields()?
394
.ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingFields))?;
395
396
let mut arrow_schema = ArrowSchema::with_capacity(fields.len());
397
let mut ipc_fields = Vec::with_capacity(fields.len());
398
399
for field in fields {
400
let (field, ipc_field) = deserialize_field(field?)?;
401
arrow_schema.insert(field.name.clone(), field);
402
ipc_fields.push(ipc_field);
403
}
404
405
let is_little_endian = match schema.endianness()? {
406
arrow_format::ipc::Endianness::Little => true,
407
arrow_format::ipc::Endianness::Big => false,
408
};
409
410
let custom_schema_metadata = match schema.custom_metadata()? {
411
None => None,
412
Some(metadata) => {
413
let metadata: Metadata = metadata
414
.into_iter()
415
.filter_map(|kv_result| {
416
// FIXME: silently hiding errors here
417
let kv_ref = kv_result.ok()?;
418
Some((kv_ref.key().ok()??.into(), kv_ref.value().ok()??.into()))
419
})
420
.collect();
421
422
if metadata.is_empty() {
423
None
424
} else {
425
Some(metadata)
426
}
427
},
428
};
429
430
Ok((
431
arrow_schema,
432
IpcSchema {
433
fields: ipc_fields,
434
is_little_endian,
435
},
436
custom_schema_metadata,
437
))
438
}
439
440
pub(super) fn deserialize_stream_metadata(meta: &[u8]) -> PolarsResult<StreamMetadata> {
441
let message = arrow_format::ipc::MessageRef::read_as_root(meta)
442
.map_err(|err| polars_err!(oos = format!("Unable to get root as message: {err:?}")))?;
443
let version = message.version()?;
444
// message header is a Schema, so read it
445
let header = message
446
.header()?
447
.ok_or_else(|| polars_err!(oos = "Unable to read the first IPC message"))?;
448
let schema = if let arrow_format::ipc::MessageHeaderRef::Schema(schema) = header {
449
schema
450
} else {
451
polars_bail!(oos = "The first IPC message of the stream must be a schema")
452
};
453
let (schema, ipc_schema, custom_schema_metadata) = fb_to_schema(schema)?;
454
455
Ok(StreamMetadata {
456
schema,
457
version,
458
ipc_schema,
459
custom_schema_metadata,
460
})
461
}
462
463