Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-parquet/src/arrow/write/schema.rs
6940 views
1
use std::borrow::Cow;
2
use std::sync::Arc;
3
4
use arrow::datatypes::{ArrowDataType, ArrowSchema, ExtensionType, Field, TimeUnit};
5
use arrow::io::ipc::write::{default_ipc_fields, schema_to_bytes};
6
use base64::Engine as _;
7
use base64::engine::general_purpose;
8
use polars_error::{PolarsResult, polars_bail};
9
use polars_utils::pl_str::PlSmallStr;
10
11
use super::super::ARROW_SCHEMA_META_KEY;
12
use super::ColumnWriteOptions;
13
use crate::arrow::write::decimal_length_from_precision;
14
use crate::parquet::metadata::KeyValue;
15
use crate::parquet::schema::Repetition;
16
use crate::parquet::schema::types::{
17
GroupConvertedType, GroupLogicalType, IntegerType, ParquetType, PhysicalType,
18
PrimitiveConvertedType, PrimitiveLogicalType, TimeUnit as ParquetTimeUnit,
19
};
20
use crate::write::ChildWriteOptions;
21
22
fn convert_field(field: Field) -> Field {
23
Field {
24
name: field.name,
25
dtype: convert_dtype(field.dtype),
26
is_nullable: field.is_nullable,
27
metadata: field.metadata,
28
}
29
}
30
31
fn convert_dtype(dtype: ArrowDataType) -> ArrowDataType {
32
use ArrowDataType as D;
33
match dtype {
34
D::LargeList(field) => D::LargeList(Box::new(convert_field(*field))),
35
D::Struct(mut fields) => {
36
for field in &mut fields {
37
*field = convert_field(std::mem::take(field))
38
}
39
D::Struct(fields)
40
},
41
D::BinaryView => D::LargeBinary,
42
D::Utf8View => D::LargeUtf8,
43
D::Dictionary(it, dtype, sorted) => {
44
let dtype = convert_dtype(*dtype);
45
D::Dictionary(it, Box::new(dtype), sorted)
46
},
47
D::Extension(ext) => {
48
let dtype = convert_dtype(ext.inner);
49
D::Extension(Box::new(ExtensionType {
50
inner: dtype,
51
..*ext
52
}))
53
},
54
dt => dt,
55
}
56
}
57
58
fn insert_field_metadata(field: &mut Cow<Field>, options: &ColumnWriteOptions) {
59
if !options.metadata.is_empty() {
60
let field = field.to_mut();
61
let mut metadata = field.metadata.as_deref().cloned().unwrap_or_default();
62
63
for kv in &options.metadata {
64
metadata.insert(
65
kv.key.as_str().into(),
66
kv.value.as_deref().unwrap_or_default().into(),
67
);
68
}
69
field.metadata = Some(Arc::new(metadata));
70
}
71
72
if let Some(v) = options.required {
73
if v == field.is_nullable {
74
let field = field.to_mut();
75
field.is_nullable = !v;
76
}
77
}
78
79
use ArrowDataType as D;
80
match field.dtype() {
81
D::Struct(f) => {
82
let ChildWriteOptions::Struct(o) = &options.children else {
83
unreachable!();
84
};
85
86
let mut new_fields = Vec::new();
87
for (i, (child_field, child_options)) in f.iter().zip(o.children.as_slice()).enumerate()
88
{
89
let mut child_field = Cow::Borrowed(child_field);
90
insert_field_metadata(&mut child_field, child_options);
91
92
if let Cow::Owned(child_field) = child_field {
93
new_fields.reserve(f.len());
94
new_fields.extend(f[..i].iter().cloned());
95
new_fields.push(child_field);
96
break;
97
}
98
}
99
100
if new_fields.is_empty() {
101
return;
102
}
103
104
new_fields.extend(
105
f[new_fields.len()..]
106
.iter()
107
.zip(&o.children[new_fields.len()..])
108
.map(|(child_field, child_options)| {
109
let mut child_field = Cow::Borrowed(child_field);
110
insert_field_metadata(&mut child_field, child_options);
111
child_field.into_owned()
112
}),
113
);
114
field
115
.to_mut()
116
.map_dtype_mut(|dtype| *dtype = D::Struct(new_fields));
117
},
118
D::List(f) | D::FixedSizeList(f, _) | D::LargeList(f) => {
119
let ChildWriteOptions::ListLike(o) = &options.children else {
120
unreachable!();
121
};
122
123
let mut child_field = Cow::Borrowed(f.as_ref());
124
insert_field_metadata(&mut child_field, &o.child);
125
126
if let Cow::Owned(child_field) = child_field {
127
let child_field = Box::new(child_field);
128
field.to_mut().map_dtype_mut(|dtype| {
129
*dtype = match dtype {
130
D::List(_) => D::List(child_field),
131
D::LargeList(_) => D::LargeList(child_field),
132
D::FixedSizeList(_, width) => D::FixedSizeList(child_field, *width),
133
_ => unreachable!(),
134
}
135
});
136
}
137
},
138
_ => {},
139
}
140
}
141
142
pub fn schema_to_metadata_key(schema: &ArrowSchema, options: &[ColumnWriteOptions]) -> KeyValue {
143
let mut schema_mut = None;
144
for (f, options) in schema.iter_values().zip(options) {
145
let mut field = Cow::Borrowed(f);
146
insert_field_metadata(&mut field, options);
147
148
if let Cow::Owned(field) = field {
149
let schema_mut = schema_mut.get_or_insert_with(|| schema.clone());
150
*schema_mut.get_mut(f.name.as_str()).unwrap() = field;
151
}
152
}
153
154
let mut schema = schema;
155
if let Some(schema_mut) = &schema_mut {
156
schema = schema_mut;
157
}
158
159
// Convert schema until more arrow readers are aware of binview
160
let serialized_schema = if schema.iter_values().any(|field| field.dtype.is_view()) {
161
let schema = schema
162
.iter_values()
163
.map(|field| convert_field(field.clone()))
164
.map(|x| (x.name.clone(), x))
165
.collect();
166
schema_to_bytes(&schema, &default_ipc_fields(schema.iter_values()), None)
167
} else {
168
schema_to_bytes(schema, &default_ipc_fields(schema.iter_values()), None)
169
};
170
171
// manually prepending the length to the schema as arrow uses the legacy IPC format
172
// TODO: change after addressing ARROW-9777
173
let schema_len = serialized_schema.len();
174
let mut len_prefix_schema = Vec::with_capacity(schema_len + 8);
175
len_prefix_schema.extend_from_slice(&[255u8, 255, 255, 255]);
176
len_prefix_schema.extend_from_slice(&(schema_len as u32).to_le_bytes());
177
len_prefix_schema.extend_from_slice(&serialized_schema);
178
179
let encoded = general_purpose::STANDARD.encode(&len_prefix_schema);
180
181
KeyValue {
182
key: ARROW_SCHEMA_META_KEY.to_string(),
183
value: Some(encoded),
184
}
185
}
186
187
/// Creates a [`ParquetType`] from a [`Field`].
188
pub fn to_parquet_type(field: &Field, options: &ColumnWriteOptions) -> PolarsResult<ParquetType> {
189
let name = field.name.clone();
190
let repetition = if options.required.unwrap_or(!field.is_nullable) {
191
Repetition::Required
192
} else {
193
Repetition::Optional
194
};
195
196
let field_id = options.field_id;
197
198
// create type from field
199
let (physical_type, primitive_converted_type, primitive_logical_type) = match field
200
.dtype()
201
.to_logical_type()
202
{
203
ArrowDataType::Null => (
204
PhysicalType::Int32,
205
None,
206
Some(PrimitiveLogicalType::Unknown),
207
),
208
ArrowDataType::Boolean => (PhysicalType::Boolean, None, None),
209
ArrowDataType::Int32 => (PhysicalType::Int32, None, None),
210
// ArrowDataType::Duration(_) has no parquet representation => do not apply any logical type
211
ArrowDataType::Int64 | ArrowDataType::Duration(_) => (PhysicalType::Int64, None, None),
212
// no natural representation in parquet; leave it as is.
213
// arrow consumers MAY use the arrow schema in the metadata to parse them.
214
ArrowDataType::Date64 => (PhysicalType::Int64, None, None),
215
ArrowDataType::Float32 => (PhysicalType::Float, None, None),
216
ArrowDataType::Float64 => (PhysicalType::Double, None, None),
217
ArrowDataType::Binary | ArrowDataType::LargeBinary | ArrowDataType::BinaryView => {
218
(PhysicalType::ByteArray, None, None)
219
},
220
ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 | ArrowDataType::Utf8View => (
221
PhysicalType::ByteArray,
222
Some(PrimitiveConvertedType::Utf8),
223
Some(PrimitiveLogicalType::String),
224
),
225
ArrowDataType::Date32 => (
226
PhysicalType::Int32,
227
Some(PrimitiveConvertedType::Date),
228
Some(PrimitiveLogicalType::Date),
229
),
230
ArrowDataType::Int8 => (
231
PhysicalType::Int32,
232
Some(PrimitiveConvertedType::Int8),
233
Some(PrimitiveLogicalType::Integer(IntegerType::Int8)),
234
),
235
ArrowDataType::Int16 => (
236
PhysicalType::Int32,
237
Some(PrimitiveConvertedType::Int16),
238
Some(PrimitiveLogicalType::Integer(IntegerType::Int16)),
239
),
240
ArrowDataType::UInt8 => (
241
PhysicalType::Int32,
242
Some(PrimitiveConvertedType::Uint8),
243
Some(PrimitiveLogicalType::Integer(IntegerType::UInt8)),
244
),
245
ArrowDataType::UInt16 => (
246
PhysicalType::Int32,
247
Some(PrimitiveConvertedType::Uint16),
248
Some(PrimitiveLogicalType::Integer(IntegerType::UInt16)),
249
),
250
ArrowDataType::UInt32 => (
251
PhysicalType::Int32,
252
Some(PrimitiveConvertedType::Uint32),
253
Some(PrimitiveLogicalType::Integer(IntegerType::UInt32)),
254
),
255
ArrowDataType::UInt64 => (
256
PhysicalType::Int64,
257
Some(PrimitiveConvertedType::Uint64),
258
Some(PrimitiveLogicalType::Integer(IntegerType::UInt64)),
259
),
260
// no natural representation in parquet; leave it as is.
261
// arrow consumers MAY use the arrow schema in the metadata to parse them.
262
ArrowDataType::Timestamp(TimeUnit::Second, _) => (PhysicalType::Int64, None, None),
263
ArrowDataType::Timestamp(time_unit, zone) => (
264
PhysicalType::Int64,
265
None,
266
Some(PrimitiveLogicalType::Timestamp {
267
is_adjusted_to_utc: matches!(zone, Some(z) if !z.as_str().is_empty()),
268
unit: match time_unit {
269
TimeUnit::Second => unreachable!(),
270
TimeUnit::Millisecond => ParquetTimeUnit::Milliseconds,
271
TimeUnit::Microsecond => ParquetTimeUnit::Microseconds,
272
TimeUnit::Nanosecond => ParquetTimeUnit::Nanoseconds,
273
},
274
}),
275
),
276
// no natural representation in parquet; leave it as is.
277
// arrow consumers MAY use the arrow schema in the metadata to parse them.
278
ArrowDataType::Time32(TimeUnit::Second) => (PhysicalType::Int32, None, None),
279
ArrowDataType::Time32(TimeUnit::Millisecond) => (
280
PhysicalType::Int32,
281
Some(PrimitiveConvertedType::TimeMillis),
282
Some(PrimitiveLogicalType::Time {
283
is_adjusted_to_utc: false,
284
unit: ParquetTimeUnit::Milliseconds,
285
}),
286
),
287
ArrowDataType::Time64(time_unit) => (
288
PhysicalType::Int64,
289
match time_unit {
290
TimeUnit::Microsecond => Some(PrimitiveConvertedType::TimeMicros),
291
TimeUnit::Nanosecond => None,
292
_ => unreachable!(),
293
},
294
Some(PrimitiveLogicalType::Time {
295
is_adjusted_to_utc: false,
296
unit: match time_unit {
297
TimeUnit::Microsecond => ParquetTimeUnit::Microseconds,
298
TimeUnit::Nanosecond => ParquetTimeUnit::Nanoseconds,
299
_ => unreachable!(),
300
},
301
}),
302
),
303
ArrowDataType::Struct(fields) => {
304
if fields.is_empty() {
305
polars_bail!(InvalidOperation:
306
"Unable to write struct type with no child field to Parquet. Consider adding a dummy child field.".to_string(),
307
)
308
}
309
310
let ChildWriteOptions::Struct(struct_write_options) = &options.children else {
311
unreachable!();
312
};
313
314
assert_eq!(fields.len(), struct_write_options.children.len());
315
316
// recursively convert children to types/nodes
317
let fields = fields
318
.iter()
319
.zip(struct_write_options.children.as_slice())
320
.map(|(f, c)| to_parquet_type(f, c))
321
.collect::<PolarsResult<Vec<_>>>()?;
322
return Ok(ParquetType::from_group(
323
name, repetition, None, None, fields, field_id,
324
));
325
},
326
ArrowDataType::Dictionary(_, value, _) => {
327
assert!(!value.is_nested());
328
let dict_field = Field::new(name, value.as_ref().clone(), field.is_nullable);
329
return to_parquet_type(&dict_field, options);
330
},
331
ArrowDataType::FixedSizeBinary(size) => {
332
(PhysicalType::FixedLenByteArray(*size), None, None)
333
},
334
ArrowDataType::Decimal(precision, scale) => {
335
let precision = *precision;
336
let scale = *scale;
337
let logical_type = Some(PrimitiveLogicalType::Decimal(precision, scale));
338
339
let physical_type = if precision <= 9 {
340
PhysicalType::Int32
341
} else if precision <= 18 {
342
PhysicalType::Int64
343
} else {
344
let len = decimal_length_from_precision(precision);
345
PhysicalType::FixedLenByteArray(len)
346
};
347
(
348
physical_type,
349
Some(PrimitiveConvertedType::Decimal(precision, scale)),
350
logical_type,
351
)
352
},
353
ArrowDataType::Decimal256(precision, scale) => {
354
let precision = *precision;
355
let scale = *scale;
356
let logical_type = Some(PrimitiveLogicalType::Decimal(precision, scale));
357
358
if precision <= 9 {
359
(
360
PhysicalType::Int32,
361
Some(PrimitiveConvertedType::Decimal(precision, scale)),
362
logical_type,
363
)
364
} else if precision <= 18 {
365
(
366
PhysicalType::Int64,
367
Some(PrimitiveConvertedType::Decimal(precision, scale)),
368
logical_type,
369
)
370
} else if precision <= 38 {
371
let len = decimal_length_from_precision(precision);
372
(
373
PhysicalType::FixedLenByteArray(len),
374
Some(PrimitiveConvertedType::Decimal(precision, scale)),
375
logical_type,
376
)
377
} else {
378
(PhysicalType::FixedLenByteArray(32), None, None)
379
}
380
},
381
ArrowDataType::Interval(_) => (
382
PhysicalType::FixedLenByteArray(12),
383
Some(PrimitiveConvertedType::Interval),
384
None,
385
),
386
ArrowDataType::Int128 => (PhysicalType::FixedLenByteArray(16), None, None),
387
ArrowDataType::List(f)
388
| ArrowDataType::FixedSizeList(f, _)
389
| ArrowDataType::LargeList(f) => {
390
let mut f = f.clone();
391
f.name = PlSmallStr::from_static("element");
392
393
let ChildWriteOptions::ListLike(list_write_options) = &options.children else {
394
unreachable!();
395
};
396
397
return Ok(ParquetType::from_group(
398
name,
399
repetition,
400
Some(GroupConvertedType::List),
401
Some(GroupLogicalType::List),
402
vec![ParquetType::from_group(
403
PlSmallStr::from_static("list"),
404
Repetition::Repeated,
405
None,
406
None,
407
vec![to_parquet_type(&f, &list_write_options.child)?],
408
None,
409
)],
410
field_id,
411
));
412
},
413
other => polars_bail!(nyi = "Writing the data type {other:?} is not yet implemented"),
414
};
415
416
Ok(ParquetType::try_from_primitive(
417
name,
418
physical_type,
419
repetition,
420
primitive_converted_type,
421
primitive_logical_type,
422
field_id,
423
)?)
424
}
425
426