GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-python/src/interop/arrow/to_rust.rs

use polars_core::POOL;
use polars_core::prelude::*;
use polars_core::utils::accumulate_dataframes_vertical_unchecked;
use polars_core::utils::arrow::ffi;
use pyo3::ffi::Py_uintptr_t;
use pyo3::prelude::*;
use pyo3::types::PyList;
use rayon::prelude::*;

use crate::error::PyPolarsErr;
use crate::utils::EnterPolarsExt;

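/// Imports a single Arrow field (name, dtype, metadata) from a PyArrow object
/// through the Arrow C data interface, returning it as an `ArrowField`.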
pub fn field_to_rust_arrow(obj: Bound<'_, PyAny>) -> PyResult<ArrowField> {
    let mut schema = Box::new(ffi::ArrowSchema::empty());
    let schema_ptr = schema.as_mut() as *mut ffi::ArrowSchema;

    // make the conversion through PyArrow's private API
    obj.call_method1("_export_to_c", (schema_ptr as Py_uintptr_t,))?;
    let field = unsafe { ffi::import_field_from_c(schema.as_ref()).map_err(PyPolarsErr::from)? };
    Ok(field)
}

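/// Imports a PyArrow field and converts it into a polars `Field`.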
pub fn field_to_rust(obj: Bound<'_, PyAny>) -> PyResult<Field> {
    field_to_rust_arrow(obj).map(|f| (&f).into())
}

// Converts a `PyList` of PyArrow fields (what you get by calling `list(schema)`
// on a PyArrow schema) into a polars `Schema`.
pub fn pyarrow_schema_to_rust(obj: &Bound<'_, PyList>) -> PyResult<Schema> {
    obj.into_iter().map(field_to_rust).collect()
}

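/// Imports a single Arrow array from a PyArrow array object through the
/// Arrow C data interface.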
pub fn array_to_rust(obj: &Bound<PyAny>) -> PyResult<ArrayRef> {
    // prepare a pointer to receive the Array struct
    let mut array = Box::new(ffi::ArrowArray::empty());
    let mut schema = Box::new(ffi::ArrowSchema::empty());

    let array_ptr = array.as_mut() as *mut ffi::ArrowArray;
    let schema_ptr = schema.as_mut() as *mut ffi::ArrowSchema;

    // make the conversion through PyArrow's private API
    // This writes through the raw pointers and is therefore unsafe; in
    // particular, `_export_to_c` can write out of bounds.
    obj.call_method1(
        "_export_to_c",
        (array_ptr as Py_uintptr_t, schema_ptr as Py_uintptr_t),
    )?;

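    // Import the exported schema and array back into Rust. Safety: PyArrow has
    // just populated both structs via `_export_to_c`.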
    unsafe {
        let field = ffi::import_field_from_c(schema.as_ref()).map_err(PyPolarsErr::from)?;
        let array = ffi::import_array_from_c(*array, field.dtype).map_err(PyPolarsErr::from)?;
        Ok(array)
    }
}

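/// Converts a sequence of PyArrow record batches, together with the top-level
/// schema (imported as a struct of fields), into a polars `DataFrame`.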
pub fn to_rust_df(
    py: Python<'_>,
    rb: &[Bound<PyAny>],
    schema: Bound<PyAny>,
) -> PyResult<DataFrame> {
    let ArrowDataType::Struct(fields) = field_to_rust_arrow(schema)?.dtype else {
        return Err(PyPolarsErr::Other("invalid top-level schema".into()).into());
    };

    let schema = ArrowSchema::from_iter(fields.iter().cloned());

    // Verify that field names are not duplicated. Arrow permits duplicate field names; we do not.
    // This is required to uphold the safety invariants of the unsafe block below.
    if schema.len() != fields.len() {
        let mut field_map: PlHashMap<PlSmallStr, u64> = PlHashMap::with_capacity(fields.len());
        fields.iter().for_each(|field| {
            field_map
                .entry(field.name.clone())
                .and_modify(|c| {
                    *c += 1;
                })
                .or_insert(1);
        });
        let duplicate_fields: Vec<_> = field_map
            .into_iter()
            .filter_map(|(k, v)| (v > 1).then_some(k))
            .collect();

        return Err(PyPolarsErr::Polars(PolarsError::Duplicate(
            format!("column appears more than once; names must be unique: {duplicate_fields:?}")
                .into(),
        ))
        .into());
    }

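    // Fast path: no record batches, so return an empty DataFrame that still
    // carries the schema's column names and dtypes.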
    if rb.is_empty() {
        let columns = schema
            .iter_values()
            .map(|field| {
                let field = Field::from(field);
                Series::new_empty(field.name, &field.dtype).into_column()
            })
            .collect::<Vec<_>>();

        // no need to check as a record batch has the same guarantees
        return Ok(unsafe { DataFrame::new_no_checks_height_from_first(columns) });
    }

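    // Convert each record batch into its own DataFrame.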
    let dfs = rb
        .iter()
        .map(|rb| {
            let mut run_parallel = false;

            let columns = (0..schema.len())
                .map(|i| {
                    let array = rb.call_method1("column", (i,))?;
                    let mut arr = array_to_rust(&array)?;

                    // Only the schema contains extension type info, restore.
                    // TODO: nested?
                    let dtype = schema.get_at_index(i).unwrap().1.dtype();
                    if let ArrowDataType::Extension(ext) = dtype {
                        if *arr.dtype() == ext.inner {
                            *arr.dtype_mut() = dtype.clone();
                        }
                    }

                    run_parallel |= matches!(
                        arr.dtype(),
                        ArrowDataType::Utf8 | ArrowDataType::Dictionary(_, _, _)
                    );
                    Ok(arr)
                })
                .collect::<PyResult<Vec<_>>>()?;

            // We parallelize this part because some dtypes are not zero-copy,
            // for instance string -> large-utf8 and
            // dict-encoded -> categorical.
            let columns = if run_parallel {
                py.enter_polars(|| {
                    POOL.install(|| {
                        columns
                            .into_par_iter()
                            .enumerate()
                            .map(|(i, arr)| {
                                let (_, field) = schema.get_at_index(i).unwrap();
                                let s = unsafe {
                                    Series::_try_from_arrow_unchecked_with_md(
                                        field.name.clone(),
                                        vec![arr],
                                        field.dtype(),
                                        field.metadata.as_deref(),
                                    )
                                }
                                .map_err(PyPolarsErr::from)?
                                .into_column();
                                Ok(s)
                            })
                            .collect::<PyResult<Vec<_>>>()
                    })
                })
            } else {
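                // No expensive conversions expected, so build the columns serially.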
                columns
                    .into_iter()
                    .enumerate()
                    .map(|(i, arr)| {
                        let (_, field) = schema.get_at_index(i).unwrap();
                        let s = unsafe {
                            Series::_try_from_arrow_unchecked_with_md(
                                field.name.clone(),
                                vec![arr],
                                field.dtype(),
                                field.metadata.as_deref(),
                            )
                        }
                        .map_err(PyPolarsErr::from)?
                        .into_column();
                        Ok(s)
                    })
                    .collect::<PyResult<Vec<_>>>()
            }?;

            // no need to check as a record batch has the same guarantees
            Ok(unsafe { DataFrame::new_no_checks_height_from_first(columns) })
        })
        .collect::<PyResult<Vec<_>>>()?;

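    // Stack the per-batch DataFrames vertically; all batches share the same schema.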
    Ok(accumulate_dataframes_vertical_unchecked(dfs))
}