Path: blob/main/crates/polars-python/src/interop/arrow/to_rust.rs
7889 views
use polars_core::POOL;1use polars_core::prelude::*;2use polars_core::utils::accumulate_dataframes_vertical_unchecked;3use polars_core::utils::arrow::ffi;4use pyo3::ffi::Py_uintptr_t;5use pyo3::prelude::*;6use pyo3::types::PyList;7use rayon::prelude::*;89use crate::error::PyPolarsErr;10use crate::utils::EnterPolarsExt;1112pub fn field_to_rust_arrow(obj: Bound<'_, PyAny>) -> PyResult<ArrowField> {13let mut schema = Box::new(ffi::ArrowSchema::empty());14let schema_ptr = schema.as_mut() as *mut ffi::ArrowSchema;1516// make the conversion through PyArrow's private API17obj.call_method1("_export_to_c", (schema_ptr as Py_uintptr_t,))?;18let field = unsafe { ffi::import_field_from_c(schema.as_ref()).map_err(PyPolarsErr::from)? };19Ok(field)20}2122pub fn field_to_rust(obj: Bound<'_, PyAny>) -> PyResult<Field> {23field_to_rust_arrow(obj).map(|f| (&f).into())24}2526// PyList<Field> which you get by calling `list(schema)`27pub fn pyarrow_schema_to_rust(obj: &Bound<'_, PyList>) -> PyResult<Schema> {28obj.into_iter().map(field_to_rust).collect()29}3031pub fn array_to_rust(obj: &Bound<PyAny>) -> PyResult<ArrayRef> {32// prepare a pointer to receive the Array struct33let mut array = Box::new(ffi::ArrowArray::empty());34let mut schema = Box::new(ffi::ArrowSchema::empty());3536let array_ptr = array.as_mut() as *mut ffi::ArrowArray;37let schema_ptr = schema.as_mut() as *mut ffi::ArrowSchema;3839// make the conversion through PyArrow's private API40// this changes the pointer's memory and is thus unsafe. In particular, `_export_to_c` can go out of bounds41obj.call_method1(42"_export_to_c",43(array_ptr as Py_uintptr_t, schema_ptr as Py_uintptr_t),44)?;4546unsafe {47let field = ffi::import_field_from_c(schema.as_ref()).map_err(PyPolarsErr::from)?;48let array = ffi::import_array_from_c(*array, field.dtype).map_err(PyPolarsErr::from)?;49Ok(array)50}51}5253pub fn to_rust_df(54py: Python<'_>,55rb: &[Bound<PyAny>],56schema: Bound<PyAny>,57) -> PyResult<DataFrame> {58let ArrowDataType::Struct(fields) = field_to_rust_arrow(schema)?.dtype else {59return Err(PyPolarsErr::Other("invalid top-level schema".into()).into());60};6162let schema = ArrowSchema::from_iter(fields.iter().cloned());6364// Verify that field names are not duplicated. Arrow permits duplicate field names, we do not.65// Required to uphold safety invariants for unsafe block below.66if schema.len() != fields.len() {67let mut field_map: PlHashMap<PlSmallStr, u64> = PlHashMap::with_capacity(fields.len());68fields.iter().for_each(|field| {69field_map70.entry(field.name.clone())71.and_modify(|c| {72*c += 1;73})74.or_insert(1);75});76let duplicate_fields: Vec<_> = field_map77.into_iter()78.filter_map(|(k, v)| (v > 1).then_some(k))79.collect();8081return Err(PyPolarsErr::Polars(PolarsError::Duplicate(82format!("column appears more than once; names must be unique: {duplicate_fields:?}")83.into(),84))85.into());86}8788if rb.is_empty() {89let columns = schema90.iter_values()91.map(|field| {92let field = Field::from(field);93Series::new_empty(field.name, &field.dtype).into_column()94})95.collect::<Vec<_>>();9697// no need to check as a record batch has the same guarantees98return Ok(unsafe { DataFrame::new_no_checks_height_from_first(columns) });99}100101let dfs = rb102.iter()103.map(|rb| {104let mut run_parallel = false;105106let columns = (0..schema.len())107.map(|i| {108let array = rb.call_method1("column", (i,))?;109let mut arr = array_to_rust(&array)?;110111// Only the schema contains extension type info, restore.112// TODO: nested?113let dtype = schema.get_at_index(i).unwrap().1.dtype();114if let ArrowDataType::Extension(ext) = dtype {115if *arr.dtype() == ext.inner {116*arr.dtype_mut() = dtype.clone();117}118}119120run_parallel |= matches!(121arr.dtype(),122ArrowDataType::Utf8 | ArrowDataType::Dictionary(_, _, _)123);124Ok(arr)125})126.collect::<PyResult<Vec<_>>>()?;127128// we parallelize this part because we can have dtypes that are not zero copy129// for instance string -> large-utf8130// dict encoded to categorical131let columns = if run_parallel {132py.enter_polars(|| {133POOL.install(|| {134columns135.into_par_iter()136.enumerate()137.map(|(i, arr)| {138let (_, field) = schema.get_at_index(i).unwrap();139let s = unsafe {140Series::_try_from_arrow_unchecked_with_md(141field.name.clone(),142vec![arr],143field.dtype(),144field.metadata.as_deref(),145)146}147.map_err(PyPolarsErr::from)?148.into_column();149Ok(s)150})151.collect::<PyResult<Vec<_>>>()152})153})154} else {155columns156.into_iter()157.enumerate()158.map(|(i, arr)| {159let (_, field) = schema.get_at_index(i).unwrap();160let s = unsafe {161Series::_try_from_arrow_unchecked_with_md(162field.name.clone(),163vec![arr],164field.dtype(),165field.metadata.as_deref(),166)167}168.map_err(PyPolarsErr::from)?169.into_column();170Ok(s)171})172.collect::<PyResult<Vec<_>>>()173}?;174175// no need to check as a record batch has the same guarantees176Ok(unsafe { DataFrame::new_no_checks_height_from_first(columns) })177})178.collect::<PyResult<Vec<_>>>()?;179180Ok(accumulate_dataframes_vertical_unchecked(dfs))181}182183184