Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-arrow/src/ffi/schema.rs
6939 views
1
use std::collections::BTreeMap;
2
use std::ffi::{CStr, CString};
3
use std::ptr;
4
5
use polars_error::{PolarsResult, polars_bail, polars_err};
6
use polars_utils::pl_str::PlSmallStr;
7
8
use super::ArrowSchema;
9
use crate::datatypes::{
10
ArrowDataType, Extension, ExtensionType, Field, IntegerType, IntervalUnit, Metadata, TimeUnit,
11
UnionMode, UnionType,
12
};
13
14
#[allow(dead_code)]
15
struct SchemaPrivateData {
16
name: CString,
17
format: CString,
18
metadata: Option<Vec<u8>>,
19
children_ptr: Box<[*mut ArrowSchema]>,
20
dictionary: Option<*mut ArrowSchema>,
21
}
22
23
// callback used to drop [ArrowSchema] when it is exported.
24
unsafe extern "C" fn c_release_schema(schema: *mut ArrowSchema) {
25
if schema.is_null() {
26
return;
27
}
28
let schema = &mut *schema;
29
30
let private = Box::from_raw(schema.private_data as *mut SchemaPrivateData);
31
for child in private.children_ptr.iter() {
32
let _ = Box::from_raw(*child);
33
}
34
35
if let Some(ptr) = private.dictionary {
36
let _ = Box::from_raw(ptr);
37
}
38
39
schema.release = None;
40
}
41
42
/// allocate (and hold) the children
43
fn schema_children(dtype: &ArrowDataType, flags: &mut i64) -> Box<[*mut ArrowSchema]> {
44
match dtype {
45
ArrowDataType::List(field)
46
| ArrowDataType::FixedSizeList(field, _)
47
| ArrowDataType::LargeList(field) => {
48
Box::new([Box::into_raw(Box::new(ArrowSchema::new(field.as_ref())))])
49
},
50
ArrowDataType::Map(field, is_sorted) => {
51
*flags += (*is_sorted as i64) * 4;
52
Box::new([Box::into_raw(Box::new(ArrowSchema::new(field.as_ref())))])
53
},
54
ArrowDataType::Struct(fields) => fields
55
.iter()
56
.map(|field| Box::into_raw(Box::new(ArrowSchema::new(field))))
57
.collect::<Box<[_]>>(),
58
ArrowDataType::Union(u) => u
59
.fields
60
.iter()
61
.map(|field| Box::into_raw(Box::new(ArrowSchema::new(field))))
62
.collect::<Box<[_]>>(),
63
ArrowDataType::Extension(ext) => schema_children(&ext.inner, flags),
64
_ => Box::new([]),
65
}
66
}
67
68
impl ArrowSchema {
69
/// creates a new [ArrowSchema]
70
pub(crate) fn new(field: &Field) -> Self {
71
let format = to_format(field.dtype());
72
let name = field.name.clone();
73
74
let mut flags = field.is_nullable as i64 * 2;
75
76
// note: this cannot be done along with the above because the above is fallible and this op leaks.
77
let children_ptr = schema_children(field.dtype(), &mut flags);
78
let n_children = children_ptr.len() as i64;
79
80
let dictionary = if let ArrowDataType::Dictionary(_, values, is_ordered) = field.dtype() {
81
flags += *is_ordered as i64;
82
// we do not store field info in the dict values, so can't recover it all :(
83
let field = Field::new(PlSmallStr::EMPTY, values.as_ref().clone(), true);
84
Some(Box::new(ArrowSchema::new(&field)))
85
} else {
86
None
87
};
88
89
let metadata = field
90
.metadata
91
.as_ref()
92
.map(|inner| (**inner).clone())
93
.unwrap_or_default();
94
95
let metadata = if let ArrowDataType::Extension(ext) = field.dtype() {
96
// append extension information.
97
let mut metadata = metadata;
98
99
// metadata
100
if let Some(extension_metadata) = &ext.metadata {
101
metadata.insert(
102
PlSmallStr::from_static("ARROW:extension:metadata"),
103
extension_metadata.clone(),
104
);
105
}
106
107
metadata.insert(
108
PlSmallStr::from_static("ARROW:extension:name"),
109
ext.name.clone(),
110
);
111
112
Some(metadata_to_bytes(&metadata))
113
} else if !metadata.is_empty() {
114
Some(metadata_to_bytes(&metadata))
115
} else {
116
None
117
};
118
119
let name = CString::new(name.as_bytes()).unwrap();
120
let format = CString::new(format).unwrap();
121
122
let mut private = Box::new(SchemaPrivateData {
123
name,
124
format,
125
metadata,
126
children_ptr,
127
dictionary: dictionary.map(Box::into_raw),
128
});
129
130
// <https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowSchema>
131
Self {
132
format: private.format.as_ptr(),
133
name: private.name.as_ptr(),
134
metadata: private
135
.metadata
136
.as_ref()
137
.map(|x| x.as_ptr())
138
.unwrap_or(std::ptr::null()) as *const ::std::os::raw::c_char,
139
flags,
140
n_children,
141
children: private.children_ptr.as_mut_ptr(),
142
dictionary: private.dictionary.unwrap_or(std::ptr::null_mut()),
143
release: Some(c_release_schema),
144
private_data: Box::into_raw(private) as *mut ::std::os::raw::c_void,
145
}
146
}
147
148
/// create an empty [ArrowSchema]
149
pub fn empty() -> Self {
150
Self {
151
format: std::ptr::null_mut(),
152
name: std::ptr::null_mut(),
153
metadata: std::ptr::null_mut(),
154
flags: 0,
155
n_children: 0,
156
children: ptr::null_mut(),
157
dictionary: std::ptr::null_mut(),
158
release: None,
159
private_data: std::ptr::null_mut(),
160
}
161
}
162
163
pub fn is_null(&self) -> bool {
164
self.private_data.is_null()
165
}
166
167
/// returns the format of this schema.
168
pub(crate) fn format(&self) -> &str {
169
assert!(!self.format.is_null());
170
// safe because the lifetime of `self.format` equals `self`
171
unsafe { CStr::from_ptr(self.format) }
172
.to_str()
173
.expect("The external API has a non-utf8 as format")
174
}
175
176
/// returns the name of this schema.
177
///
178
/// Since this field is optional, `""` is returned if it is not set (as per the spec).
179
pub(crate) fn name(&self) -> &str {
180
if self.name.is_null() {
181
return "";
182
}
183
// safe because the lifetime of `self.name` equals `self`
184
unsafe { CStr::from_ptr(self.name) }.to_str().unwrap()
185
}
186
187
pub(crate) fn child(&self, index: usize) -> &'static Self {
188
assert!(index < self.n_children as usize);
189
unsafe { self.children.add(index).as_ref().unwrap().as_ref().unwrap() }
190
}
191
192
pub(crate) fn dictionary(&self) -> Option<&'static Self> {
193
if self.dictionary.is_null() {
194
return None;
195
};
196
Some(unsafe { self.dictionary.as_ref().unwrap() })
197
}
198
199
pub(crate) fn nullable(&self) -> bool {
200
(self.flags / 2) & 1 == 1
201
}
202
}
203
204
impl Drop for ArrowSchema {
205
fn drop(&mut self) {
206
match self.release {
207
None => (),
208
Some(release) => unsafe { release(self) },
209
};
210
}
211
}
212
213
pub(crate) unsafe fn to_field(schema: &ArrowSchema) -> PolarsResult<Field> {
214
let dictionary = schema.dictionary();
215
let dtype = if let Some(dictionary) = dictionary {
216
let indices = to_integer_type(schema.format())?;
217
let values = to_field(dictionary)?;
218
let is_ordered = schema.flags & 1 == 1;
219
ArrowDataType::Dictionary(indices, Box::new(values.dtype().clone()), is_ordered)
220
} else {
221
to_dtype(schema)?
222
};
223
let (metadata, extension) = unsafe { metadata_from_bytes(schema.metadata) };
224
225
let dtype = if let Some((name, extension_metadata)) = extension {
226
ArrowDataType::Extension(Box::new(ExtensionType {
227
name,
228
inner: dtype,
229
metadata: extension_metadata,
230
}))
231
} else {
232
dtype
233
};
234
235
Ok(Field::new(
236
PlSmallStr::from_str(schema.name()),
237
dtype,
238
schema.nullable(),
239
)
240
.with_metadata(metadata))
241
}
242
243
fn to_integer_type(format: &str) -> PolarsResult<IntegerType> {
244
use IntegerType::*;
245
Ok(match format {
246
"c" => Int8,
247
"C" => UInt8,
248
"s" => Int16,
249
"S" => UInt16,
250
"i" => Int32,
251
"I" => UInt32,
252
"l" => Int64,
253
"L" => UInt64,
254
_ => {
255
polars_bail!(
256
ComputeError:
257
"dictionary indices can only be integers"
258
)
259
},
260
})
261
}
262
263
unsafe fn to_dtype(schema: &ArrowSchema) -> PolarsResult<ArrowDataType> {
264
Ok(match schema.format() {
265
"n" => ArrowDataType::Null,
266
"b" => ArrowDataType::Boolean,
267
"c" => ArrowDataType::Int8,
268
"C" => ArrowDataType::UInt8,
269
"s" => ArrowDataType::Int16,
270
"S" => ArrowDataType::UInt16,
271
"i" => ArrowDataType::Int32,
272
"I" => ArrowDataType::UInt32,
273
"l" => ArrowDataType::Int64,
274
"L" => ArrowDataType::UInt64,
275
"_pli128" => ArrowDataType::Int128,
276
"e" => ArrowDataType::Float16,
277
"f" => ArrowDataType::Float32,
278
"g" => ArrowDataType::Float64,
279
"z" => ArrowDataType::Binary,
280
"Z" => ArrowDataType::LargeBinary,
281
"u" => ArrowDataType::Utf8,
282
"U" => ArrowDataType::LargeUtf8,
283
"tdD" => ArrowDataType::Date32,
284
"tdm" => ArrowDataType::Date64,
285
"tts" => ArrowDataType::Time32(TimeUnit::Second),
286
"ttm" => ArrowDataType::Time32(TimeUnit::Millisecond),
287
"ttu" => ArrowDataType::Time64(TimeUnit::Microsecond),
288
"ttn" => ArrowDataType::Time64(TimeUnit::Nanosecond),
289
"tDs" => ArrowDataType::Duration(TimeUnit::Second),
290
"tDm" => ArrowDataType::Duration(TimeUnit::Millisecond),
291
"tDu" => ArrowDataType::Duration(TimeUnit::Microsecond),
292
"tDn" => ArrowDataType::Duration(TimeUnit::Nanosecond),
293
"tiM" => ArrowDataType::Interval(IntervalUnit::YearMonth),
294
"tiD" => ArrowDataType::Interval(IntervalUnit::DayTime),
295
"vu" => ArrowDataType::Utf8View,
296
"vz" => ArrowDataType::BinaryView,
297
"+l" => {
298
let child = schema.child(0);
299
ArrowDataType::List(Box::new(to_field(child)?))
300
},
301
"+L" => {
302
let child = schema.child(0);
303
ArrowDataType::LargeList(Box::new(to_field(child)?))
304
},
305
"+m" => {
306
let child = schema.child(0);
307
308
let is_sorted = (schema.flags & 4) != 0;
309
ArrowDataType::Map(Box::new(to_field(child)?), is_sorted)
310
},
311
"+s" => {
312
let children = (0..schema.n_children as usize)
313
.map(|x| to_field(schema.child(x)))
314
.collect::<PolarsResult<Vec<_>>>()?;
315
ArrowDataType::Struct(children)
316
},
317
other => {
318
match other.splitn(2, ':').collect::<Vec<_>>()[..] {
319
// Timestamps with no timezone
320
["tss", ""] => ArrowDataType::Timestamp(TimeUnit::Second, None),
321
["tsm", ""] => ArrowDataType::Timestamp(TimeUnit::Millisecond, None),
322
["tsu", ""] => ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
323
["tsn", ""] => ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
324
325
// Timestamps with timezone
326
["tss", tz] => {
327
ArrowDataType::Timestamp(TimeUnit::Second, Some(PlSmallStr::from_str(tz)))
328
},
329
["tsm", tz] => {
330
ArrowDataType::Timestamp(TimeUnit::Millisecond, Some(PlSmallStr::from_str(tz)))
331
},
332
["tsu", tz] => {
333
ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(PlSmallStr::from_str(tz)))
334
},
335
["tsn", tz] => {
336
ArrowDataType::Timestamp(TimeUnit::Nanosecond, Some(PlSmallStr::from_str(tz)))
337
},
338
339
["w", size_raw] => {
340
// Example: "w:42" fixed-width binary [42 bytes]
341
let size = size_raw
342
.parse::<usize>()
343
.map_err(|_| polars_err!(ComputeError: "size is not a valid integer"))?;
344
ArrowDataType::FixedSizeBinary(size)
345
},
346
["+w", size_raw] => {
347
// Example: "+w:123" fixed-sized list [123 items]
348
let size = size_raw
349
.parse::<usize>()
350
.map_err(|_| polars_err!(ComputeError: "size is not a valid integer"))?;
351
let child = to_field(schema.child(0))?;
352
ArrowDataType::FixedSizeList(Box::new(child), size)
353
},
354
["d", raw] => {
355
// Decimal
356
let (precision, scale) = match raw.split(',').collect::<Vec<_>>()[..] {
357
[precision_raw, scale_raw] => {
358
// Example: "d:19,10" decimal128 [precision 19, scale 10]
359
(precision_raw, scale_raw)
360
},
361
[precision_raw, scale_raw, width_raw] => {
362
// Example: "d:19,10,NNN" decimal bitwidth = NNN [precision 19, scale 10]
363
// Only bitwdth of 128 currently supported
364
let bit_width = width_raw.parse::<usize>().map_err(|_| {
365
polars_err!(ComputeError: "Decimal bit width is not a valid integer")
366
})?;
367
match bit_width {
368
32 => return Ok(ArrowDataType::Decimal32(
369
precision_raw.parse::<usize>().map_err(|_| {
370
polars_err!(ComputeError: "Decimal precision is not a valid integer")
371
})?,
372
scale_raw.parse::<usize>().map_err(|_| {
373
polars_err!(ComputeError: "Decimal scale is not a valid integer")
374
})?,
375
)),
376
64 => return Ok(ArrowDataType::Decimal64(
377
precision_raw.parse::<usize>().map_err(|_| {
378
polars_err!(ComputeError: "Decimal precision is not a valid integer")
379
})?,
380
scale_raw.parse::<usize>().map_err(|_| {
381
polars_err!(ComputeError: "Decimal scale is not a valid integer")
382
})?,
383
)),
384
256 => return Ok(ArrowDataType::Decimal256(
385
precision_raw.parse::<usize>().map_err(|_| {
386
polars_err!(ComputeError: "Decimal precision is not a valid integer")
387
})?,
388
scale_raw.parse::<usize>().map_err(|_| {
389
polars_err!(ComputeError: "Decimal scale is not a valid integer")
390
})?,
391
)),
392
_ => {},
393
}
394
(precision_raw, scale_raw)
395
},
396
_ => {
397
polars_bail!(ComputeError:
398
"Decimal must contain 2 or 3 comma-separated values"
399
)
400
},
401
};
402
403
ArrowDataType::Decimal(
404
precision.parse::<usize>().map_err(|_| {
405
polars_err!(ComputeError:
406
"Decimal precision is not a valid integer"
407
)
408
})?,
409
scale.parse::<usize>().map_err(|_| {
410
polars_err!(ComputeError:
411
"Decimal scale is not a valid integer"
412
)
413
})?,
414
)
415
},
416
[union_type @ "+us", union_parts] | [union_type @ "+ud", union_parts] => {
417
// union, sparse
418
// Example "+us:I,J,..." sparse union with type ids I,J...
419
// Example: "+ud:I,J,..." dense union with type ids I,J...
420
let mode = UnionMode::sparse(union_type == "+us");
421
let type_ids = union_parts
422
.split(',')
423
.map(|x| {
424
x.parse::<i32>().map_err(|_| {
425
polars_err!(ComputeError:
426
"Union type id is not a valid integer"
427
)
428
})
429
})
430
.collect::<PolarsResult<Vec<_>>>()?;
431
let fields = (0..schema.n_children as usize)
432
.map(|x| to_field(schema.child(x)))
433
.collect::<PolarsResult<Vec<_>>>()?;
434
ArrowDataType::Union(Box::new(UnionType {
435
fields,
436
ids: Some(type_ids),
437
mode,
438
}))
439
},
440
_ => {
441
polars_bail!(ComputeError:
442
"The datatype \"{other}\" is still not supported in Rust implementation",
443
)
444
},
445
}
446
},
447
})
448
}
449
450
/// the inverse of [to_field]
451
fn to_format(dtype: &ArrowDataType) -> String {
452
match dtype {
453
ArrowDataType::Null => "n".to_string(),
454
ArrowDataType::Boolean => "b".to_string(),
455
ArrowDataType::Int8 => "c".to_string(),
456
ArrowDataType::UInt8 => "C".to_string(),
457
ArrowDataType::Int16 => "s".to_string(),
458
ArrowDataType::UInt16 => "S".to_string(),
459
ArrowDataType::Int32 => "i".to_string(),
460
ArrowDataType::UInt32 => "I".to_string(),
461
ArrowDataType::Int64 => "l".to_string(),
462
ArrowDataType::UInt64 => "L".to_string(),
463
// Doesn't exist in arrow, '_pl' prefixed is Polars specific
464
ArrowDataType::Int128 => "_pli128".to_string(),
465
ArrowDataType::Float16 => "e".to_string(),
466
ArrowDataType::Float32 => "f".to_string(),
467
ArrowDataType::Float64 => "g".to_string(),
468
ArrowDataType::Binary => "z".to_string(),
469
ArrowDataType::LargeBinary => "Z".to_string(),
470
ArrowDataType::Utf8 => "u".to_string(),
471
ArrowDataType::LargeUtf8 => "U".to_string(),
472
ArrowDataType::Date32 => "tdD".to_string(),
473
ArrowDataType::Date64 => "tdm".to_string(),
474
ArrowDataType::Time32(TimeUnit::Second) => "tts".to_string(),
475
ArrowDataType::Time32(TimeUnit::Millisecond) => "ttm".to_string(),
476
ArrowDataType::Time32(_) => {
477
unreachable!("Time32 is only supported for seconds and milliseconds")
478
},
479
ArrowDataType::Time64(TimeUnit::Microsecond) => "ttu".to_string(),
480
ArrowDataType::Time64(TimeUnit::Nanosecond) => "ttn".to_string(),
481
ArrowDataType::Time64(_) => {
482
unreachable!("Time64 is only supported for micro and nanoseconds")
483
},
484
ArrowDataType::Duration(TimeUnit::Second) => "tDs".to_string(),
485
ArrowDataType::Duration(TimeUnit::Millisecond) => "tDm".to_string(),
486
ArrowDataType::Duration(TimeUnit::Microsecond) => "tDu".to_string(),
487
ArrowDataType::Duration(TimeUnit::Nanosecond) => "tDn".to_string(),
488
ArrowDataType::Interval(IntervalUnit::YearMonth) => "tiM".to_string(),
489
ArrowDataType::Interval(IntervalUnit::DayTime) => "tiD".to_string(),
490
ArrowDataType::Interval(IntervalUnit::MonthDayNano) => {
491
todo!("Spec for FFI for MonthDayNano still not defined.")
492
},
493
ArrowDataType::Timestamp(unit, tz) => {
494
let unit = match unit {
495
TimeUnit::Second => "s",
496
TimeUnit::Millisecond => "m",
497
TimeUnit::Microsecond => "u",
498
TimeUnit::Nanosecond => "n",
499
};
500
format!(
501
"ts{}:{}",
502
unit,
503
tz.as_ref().map(|x| x.as_str()).unwrap_or("")
504
)
505
},
506
ArrowDataType::Utf8View => "vu".to_string(),
507
ArrowDataType::BinaryView => "vz".to_string(),
508
ArrowDataType::Decimal(precision, scale) => format!("d:{precision},{scale}"),
509
ArrowDataType::Decimal32(precision, scale) => format!("d:{precision},{scale},32"),
510
ArrowDataType::Decimal64(precision, scale) => format!("d:{precision},{scale},64"),
511
ArrowDataType::Decimal256(precision, scale) => format!("d:{precision},{scale},256"),
512
ArrowDataType::List(_) => "+l".to_string(),
513
ArrowDataType::LargeList(_) => "+L".to_string(),
514
ArrowDataType::Struct(_) => "+s".to_string(),
515
ArrowDataType::FixedSizeBinary(size) => format!("w:{size}"),
516
ArrowDataType::FixedSizeList(_, size) => format!("+w:{size}"),
517
ArrowDataType::Union(u) => {
518
let sparsness = if u.mode.is_sparse() { 's' } else { 'd' };
519
let mut r = format!("+u{sparsness}:");
520
let ids = if let Some(ids) = &u.ids {
521
ids.iter()
522
.fold(String::new(), |a, b| a + b.to_string().as_str() + ",")
523
} else {
524
(0..u.fields.len()).fold(String::new(), |a, b| a + b.to_string().as_str() + ",")
525
};
526
let ids = &ids[..ids.len() - 1]; // take away last ","
527
r.push_str(ids);
528
r
529
},
530
ArrowDataType::Map(_, _) => "+m".to_string(),
531
ArrowDataType::Dictionary(index, _, _) => to_format(&(*index).into()),
532
ArrowDataType::Extension(ext) => to_format(&ext.inner),
533
ArrowDataType::Unknown => unimplemented!(),
534
}
535
}
536
537
pub(super) fn get_child(dtype: &ArrowDataType, index: usize) -> PolarsResult<ArrowDataType> {
538
match (index, dtype) {
539
(0, ArrowDataType::List(field)) => Ok(field.dtype().clone()),
540
(0, ArrowDataType::FixedSizeList(field, _)) => Ok(field.dtype().clone()),
541
(0, ArrowDataType::LargeList(field)) => Ok(field.dtype().clone()),
542
(0, ArrowDataType::Map(field, _)) => Ok(field.dtype().clone()),
543
(index, ArrowDataType::Struct(fields)) => Ok(fields[index].dtype().clone()),
544
(index, ArrowDataType::Union(u)) => Ok(u.fields[index].dtype().clone()),
545
(index, ArrowDataType::Extension(ext)) => get_child(&ext.inner, index),
546
(child, dtype) => polars_bail!(ComputeError:
547
"Requested child {child} to type {dtype:?} that has no such child",
548
),
549
}
550
}
551
552
fn metadata_to_bytes(metadata: &BTreeMap<PlSmallStr, PlSmallStr>) -> Vec<u8> {
553
let a = (metadata.len() as i32).to_ne_bytes().to_vec();
554
metadata.iter().fold(a, |mut acc, (key, value)| {
555
acc.extend((key.len() as i32).to_ne_bytes());
556
acc.extend(key.as_bytes());
557
acc.extend((value.len() as i32).to_ne_bytes());
558
acc.extend(value.as_bytes());
559
acc
560
})
561
}
562
563
/// Reads a native-endian `i32` from the four bytes starting at `ptr`.
///
/// # Safety
/// `ptr` must be valid for reading four bytes. No alignment is required.
unsafe fn read_ne_i32(ptr: *const u8) -> i32 {
    let mut raw = [0u8; 4];
    // SAFETY: the caller guarantees four readable bytes at `ptr`; `raw` is a distinct local
    // buffer of exactly four bytes, so the two ranges cannot overlap.
    unsafe { std::ptr::copy_nonoverlapping(ptr, raw.as_mut_ptr(), raw.len()) };
    i32::from_ne_bytes(raw)
}
567
568
unsafe fn read_bytes(ptr: *const u8, len: usize) -> &'static str {
569
let slice = std::slice::from_raw_parts(ptr, len);
570
simdutf8::basic::from_utf8(slice).unwrap()
571
}
572
573
unsafe fn metadata_from_bytes(data: *const ::std::os::raw::c_char) -> (Metadata, Extension) {
574
let mut data = data as *const u8; // u8 = i8
575
if data.is_null() {
576
return (Metadata::default(), None);
577
};
578
let len = read_ne_i32(data);
579
data = data.add(4);
580
581
let mut result = BTreeMap::new();
582
let mut extension_name = None;
583
let mut extension_metadata = None;
584
for _ in 0..len {
585
let key_len = read_ne_i32(data) as usize;
586
data = data.add(4);
587
let key = read_bytes(data, key_len);
588
data = data.add(key_len);
589
let value_len = read_ne_i32(data) as usize;
590
data = data.add(4);
591
let value = read_bytes(data, value_len);
592
data = data.add(value_len);
593
match key {
594
"ARROW:extension:name" => {
595
extension_name = Some(PlSmallStr::from_str(value));
596
},
597
"ARROW:extension:metadata" => {
598
extension_metadata = Some(PlSmallStr::from_str(value));
599
},
600
_ => {
601
result.insert(PlSmallStr::from_str(key), PlSmallStr::from_str(value));
602
},
603
};
604
}
605
let extension = extension_name.map(|name| (name, extension_metadata));
606
(result, extension)
607
}
608
609
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array::LIST_VALUES_NAME;

    /// Non-nullable boolean item field shared by the list-like types.
    fn boolean_item() -> Box<Field> {
        Box::new(Field::new(
            PlSmallStr::from_static("example"),
            ArrowDataType::Boolean,
            false,
        ))
    }

    /// Two fields, `a: Int64` and `b: List<Int32>`, shared by the struct and union types.
    fn nested_fields() -> Vec<Field> {
        vec![
            Field::new(PlSmallStr::from_static("a"), ArrowDataType::Int64, true),
            Field::new(
                PlSmallStr::from_static("b"),
                ArrowDataType::List(Box::new(Field::new(
                    LIST_VALUES_NAME,
                    ArrowDataType::Int32,
                    true,
                ))),
                true,
            ),
        ]
    }

    /// Every listed data type must survive the trip through an exported schema unchanged.
    #[test]
    fn test_all() {
        let mut dts = vec![
            ArrowDataType::Null,
            ArrowDataType::Boolean,
            ArrowDataType::UInt8,
            ArrowDataType::UInt16,
            ArrowDataType::UInt32,
            ArrowDataType::UInt64,
            ArrowDataType::Int8,
            ArrowDataType::Int16,
            ArrowDataType::Int32,
            ArrowDataType::Int64,
            ArrowDataType::Int128,
            ArrowDataType::Float16,
            ArrowDataType::Float32,
            ArrowDataType::Float64,
            ArrowDataType::Date32,
            ArrowDataType::Date64,
            ArrowDataType::Time32(TimeUnit::Second),
            ArrowDataType::Time32(TimeUnit::Millisecond),
            ArrowDataType::Time64(TimeUnit::Microsecond),
            ArrowDataType::Time64(TimeUnit::Nanosecond),
            ArrowDataType::Decimal(5, 5),
            ArrowDataType::Decimal32(5, 5),
            ArrowDataType::Decimal64(5, 5),
            ArrowDataType::Decimal256(5, 5),
            ArrowDataType::Utf8,
            ArrowDataType::LargeUtf8,
            ArrowDataType::Utf8View,
            ArrowDataType::Binary,
            ArrowDataType::LargeBinary,
            ArrowDataType::BinaryView,
            ArrowDataType::FixedSizeBinary(2),
            ArrowDataType::List(boolean_item()),
            ArrowDataType::FixedSizeList(boolean_item(), 2),
            ArrowDataType::LargeList(boolean_item()),
            ArrowDataType::Struct(nested_fields()),
            ArrowDataType::Map(
                Box::new(Field::new(
                    PlSmallStr::from_static("a"),
                    ArrowDataType::Int64,
                    true,
                )),
                true,
            ),
            ArrowDataType::Union(Box::new(UnionType {
                fields: nested_fields(),
                ids: Some(vec![1, 2]),
                mode: UnionMode::Dense,
            })),
            ArrowDataType::Union(Box::new(UnionType {
                fields: nested_fields(),
                ids: Some(vec![0, 1]),
                mode: UnionMode::Sparse,
            })),
        ];
        for time_unit in [
            TimeUnit::Second,
            TimeUnit::Millisecond,
            TimeUnit::Microsecond,
            TimeUnit::Nanosecond,
        ] {
            dts.push(ArrowDataType::Timestamp(time_unit, None));
            dts.push(ArrowDataType::Timestamp(
                time_unit,
                Some(PlSmallStr::from_static("00:00")),
            ));
            dts.push(ArrowDataType::Duration(time_unit));
        }
        for interval_type in [
            IntervalUnit::DayTime,
            IntervalUnit::YearMonth,
            //IntervalUnit::MonthDayNano, // not yet defined on the C data interface
        ] {
            dts.push(ArrowDataType::Interval(interval_type));
        }

        for expected in dts {
            let field = Field::new(PlSmallStr::from_static("a"), expected.clone(), true);
            let schema = ArrowSchema::new(&field);
            let result = unsafe { super::to_dtype(&schema).unwrap() };
            assert_eq!(result, expected);
        }
    }
}
741
742