Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-python/src/interop/numpy/utils.rs
8346 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use std::ffi::{c_int, c_void};
3
4
use ndarray::{Dim, Dimension};
5
use numpy::npyffi::PyArrayObject;
6
use numpy::{Element, PY_ARRAY_API, PyArrayDescr, PyArrayDescrMethods, ToNpyDims, npyffi};
7
use polars_core::prelude::*;
8
use pyo3::intern;
9
use pyo3::prelude::*;
10
use pyo3::types::PyTuple;
11
12
pub(super) fn get_numpy_module(py: Python) -> PyResult<Bound<PyModule>> {
13
PyModule::import(py, intern!(py, "numpy"))
14
}
15
16
/// Create a NumPy ndarray view of the data.
17
pub(super) unsafe fn create_borrowed_np_array<I>(
18
py: Python<'_>,
19
dtype: Bound<PyArrayDescr>,
20
mut shape: Dim<I>,
21
flags: c_int,
22
data: *mut c_void,
23
owner: Py<PyAny>,
24
) -> Py<PyAny>
25
where
26
Dim<I>: Dimension + ToNpyDims,
27
{
28
// See: https://numpy.org/doc/stable/reference/c-api/array.html
29
let array = PY_ARRAY_API.PyArray_NewFromDescr(
30
py,
31
PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type),
32
dtype.into_dtype_ptr(),
33
shape.ndim_cint(),
34
shape.as_dims_ptr(),
35
// We don't provide strides, but provide flags that tell c/f-order
36
std::ptr::null_mut(),
37
data,
38
flags,
39
std::ptr::null_mut(),
40
);
41
42
// This keeps the memory alive
43
let owner_ptr = owner.as_ptr();
44
// SetBaseObject steals a reference
45
// so we can forget.
46
std::mem::forget(owner);
47
PY_ARRAY_API.PyArray_SetBaseObject(py, array as *mut PyArrayObject, owner_ptr);
48
49
Py::from_owned_ptr(py, array)
50
}
51
52
/// Returns whether the data type supports creating a NumPy view.
53
pub(super) fn dtype_supports_view(dtype: &DataType) -> bool {
54
match dtype {
55
dt if dt.is_primitive_numeric() => true,
56
DataType::Datetime(_, _) | DataType::Duration(_) => true,
57
DataType::Array(inner, _) => dtype_supports_view(inner.as_ref()),
58
_ => false,
59
}
60
}
61
62
/// Returns whether the Series contains nulls at any level of nesting.
63
///
64
/// Of the nested types, only Array types are handled since only those are relevant for NumPy views.
65
pub(super) fn series_contains_null(s: &Series) -> bool {
66
if s.null_count() > 0 {
67
true
68
} else if let Ok(ca) = s.array() {
69
let s_inner = ca.get_inner();
70
series_contains_null(&s_inner)
71
} else {
72
false
73
}
74
}
75
76
/// Reshape the first dimension of a NumPy array to the given height and width.
77
pub(super) fn reshape_numpy_array(
78
py: Python<'_>,
79
arr: Py<PyAny>,
80
height: usize,
81
width: usize,
82
) -> PyResult<Py<PyAny>> {
83
let shape = arr
84
.getattr(py, intern!(py, "shape"))?
85
.extract::<Vec<usize>>(py)?;
86
87
if shape.len() == 1 {
88
// In this case, we can avoid allocating a Vec.
89
let new_shape = (height, width);
90
arr.call_method1(py, intern!(py, "reshape"), new_shape)
91
} else {
92
let mut new_shape_vec = vec![height, width];
93
for v in &shape[1..] {
94
new_shape_vec.push(*v)
95
}
96
let new_shape = PyTuple::new(py, new_shape_vec)?;
97
arr.call_method1(py, intern!(py, "reshape"), new_shape)
98
}
99
}
100
101
/// Get the NumPy temporal data type associated with the given Polars [`DataType`].
102
pub(super) fn polars_dtype_to_np_temporal_dtype<'py>(
103
py: Python<'py>,
104
dtype: &DataType,
105
) -> Bound<'py, PyArrayDescr> {
106
use numpy::datetime::{Datetime, Timedelta, units};
107
match dtype {
108
DataType::Datetime(TimeUnit::Milliseconds, _) => {
109
Datetime::<units::Milliseconds>::get_dtype(py)
110
},
111
DataType::Datetime(TimeUnit::Microseconds, _) => {
112
Datetime::<units::Microseconds>::get_dtype(py)
113
},
114
DataType::Datetime(TimeUnit::Nanoseconds, _) => {
115
Datetime::<units::Nanoseconds>::get_dtype(py)
116
},
117
DataType::Duration(TimeUnit::Milliseconds) => {
118
Timedelta::<units::Milliseconds>::get_dtype(py)
119
},
120
DataType::Duration(TimeUnit::Microseconds) => {
121
Timedelta::<units::Microseconds>::get_dtype(py)
122
},
123
DataType::Duration(TimeUnit::Nanoseconds) => Timedelta::<units::Nanoseconds>::get_dtype(py),
124
_ => panic!("only Datetime/Duration inputs supported, got {dtype}"),
125
}
126
}
127
128