Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-python/src/on_startup.rs
7884 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use std::any::Any;
3
use std::sync::OnceLock;
4
5
use arrow::array::Array;
6
use polars::chunked_array::object::ObjectArray;
7
use polars::prelude::sink2::FileProviderReturn;
8
use polars::prelude::*;
9
use polars_core::chunked_array::object::builder::ObjectChunkedBuilder;
10
use polars_core::chunked_array::object::registry::AnonymousObjectBuilder;
11
use polars_core::chunked_array::object::{registry, set_polars_allow_extension};
12
use polars_error::PolarsWarning;
13
use polars_error::signals::register_polars_keyboard_interrupt_hook;
14
use polars_ffi::version_0::SeriesExport;
15
use polars_plan::plans::python_df_to_rust;
16
use polars_utils::python_convert_registry::{FromPythonConvertRegistry, PythonConvertRegistry};
17
use pyo3::prelude::*;
18
use pyo3::{IntoPyObjectExt, intern};
19
20
use crate::Wrap;
21
use crate::dataframe::PyDataFrame;
22
use crate::lazyframe::PyLazyFrame;
23
use crate::map::lazy::call_lambda_with_series;
24
use crate::prelude::ObjectValue;
25
use crate::py_modules::{pl_df, pl_utils, polars, polars_rs};
26
use crate::series::PySeries;
27
28
fn python_function_caller_series(
29
s: &[Column],
30
output_dtype: Option<DataType>,
31
lambda: &Py<PyAny>,
32
) -> PolarsResult<Column> {
33
Python::attach(|py| call_lambda_with_series(py, s, output_dtype, lambda))
34
}
35
36
fn python_function_caller_df(df: DataFrame, lambda: &Py<PyAny>) -> PolarsResult<DataFrame> {
37
Python::attach(|py| {
38
let pypolars = polars(py).bind(py);
39
40
// create a PySeries struct/object for Python
41
let pydf = PyDataFrame::new(df);
42
// Wrap this PySeries object in the python side Series wrapper
43
let mut python_df_wrapper = pypolars
44
.getattr("wrap_df")
45
.unwrap()
46
.call1((pydf.clone(),))
47
.unwrap();
48
49
if !python_df_wrapper
50
.getattr("_df")
51
.unwrap()
52
.is_instance(polars_rs(py).getattr(py, "PyDataFrame").unwrap().bind(py))
53
.unwrap()
54
{
55
let pldf = pl_df(py).bind(py);
56
let width = pydf.width();
57
// Don't resize the Vec to avoid calling SeriesExport's Drop impl
58
// The import takes ownership and is responsible for dropping
59
let mut columns: Vec<SeriesExport> = Vec::with_capacity(width);
60
unsafe {
61
pydf._export_columns(columns.as_mut_ptr() as usize);
62
}
63
// Wrap this PyDataFrame object in the python side DataFrame wrapper
64
python_df_wrapper = pldf
65
.getattr("_import_columns")
66
.unwrap()
67
.call1((columns.as_mut_ptr() as usize, width))
68
.unwrap();
69
}
70
// call the lambda and get a python side df wrapper
71
let result_df_wrapper = lambda.call1(py, (python_df_wrapper,))?;
72
73
// unpack the wrapper in a PyDataFrame
74
let py_pydf = result_df_wrapper.getattr(py, "_df").map_err(|_| {
75
let pytype = result_df_wrapper.bind(py).get_type();
76
PolarsError::ComputeError(
77
format!("Expected 'LazyFrame.map' to return a 'DataFrame', got a '{pytype}'",)
78
.into(),
79
)
80
})?;
81
// Downcast to Rust
82
match py_pydf.extract::<PyDataFrame>(py) {
83
Ok(pydf) => Ok(pydf.df.into_inner()),
84
Err(_) => python_df_to_rust(py, result_df_wrapper.into_bound(py)),
85
}
86
})
87
}
88
89
fn warning_function(msg: &str, warning: PolarsWarning) {
90
Python::attach(|py| {
91
let warn_fn = pl_utils(py)
92
.bind(py)
93
.getattr(intern!(py, "_polars_warn"))
94
.unwrap();
95
96
if let Err(e) = warn_fn.call1((msg, Wrap(warning).into_pyobject(py).unwrap())) {
97
eprintln!("{e}")
98
}
99
});
100
}
101
102
static POLARS_REGISTRY_INIT_LOCK: OnceLock<()> = OnceLock::new();
103
104
/// # Safety
105
/// Caller must ensure that no other threads read the objects set by this registration.
106
pub unsafe fn register_startup_deps(catch_keyboard_interrupt: bool) {
107
// TODO: should we throw an error if we try to initialize while already initialized?
108
POLARS_REGISTRY_INIT_LOCK.get_or_init(|| {
109
set_polars_allow_extension(true);
110
111
// Stack frames can get really large in debug mode.
112
#[cfg(debug_assertions)]
113
{
114
recursive::set_minimum_stack_size(1024 * 1024);
115
recursive::set_stack_allocation_size(1024 * 1024 * 16);
116
}
117
118
// Register object type builder.
119
let object_builder = Box::new(|name: PlSmallStr, capacity: usize| {
120
Box::new(ObjectChunkedBuilder::<ObjectValue>::new(name, capacity))
121
as Box<dyn AnonymousObjectBuilder>
122
});
123
124
let object_converter = Arc::new(|av: AnyValue| {
125
let object = Python::attach(|py| ObjectValue {
126
inner: Wrap(av).into_py_any(py).unwrap(),
127
});
128
Box::new(object) as Box<dyn Any>
129
});
130
let pyobject_converter = Arc::new(|av: AnyValue| {
131
let object = Python::attach(|py| Wrap(av).into_py_any(py).unwrap());
132
Box::new(object) as Box<dyn Any>
133
});
134
fn object_array_getter(arr: &dyn Array, idx: usize) -> Option<AnyValue<'_>> {
135
let arr = arr.as_any().downcast_ref::<ObjectArray<ObjectValue>>().unwrap();
136
arr.get(idx).map(|v| AnyValue::Object(v))
137
}
138
139
polars_utils::python_convert_registry::register_converters(PythonConvertRegistry {
140
from_py: FromPythonConvertRegistry {
141
partition_target_cb_result: Arc::new(|py_f| {
142
Python::attach(|py| {
143
Ok(Box::new(
144
py_f.extract::<Wrap<polars_plan::dsl::PartitionTargetCallbackResult>>(
145
py,
146
)?
147
.0,
148
) as _)
149
})
150
}),
151
file_provider_result: Arc::new(|py_f| {
152
Python::attach(|py| {
153
Ok(Box::new(py_f.extract::<Wrap<FileProviderReturn>>(py)?.0) as _)
154
})
155
}),
156
series: Arc::new(|py_f| {
157
Python::attach(|py| {
158
Ok(Box::new(py_f.extract::<PySeries>(py)?.series.into_inner()) as _)
159
})
160
}),
161
df: Arc::new(|py_f| {
162
Python::attach(|py| {
163
Ok(Box::new(py_f.extract::<PyDataFrame>(py)?.df.into_inner()) as _)
164
})
165
}),
166
dsl_plan: Arc::new(|py_f| {
167
Python::attach(|py| {
168
Ok(Box::new(
169
py_f.extract::<PyLazyFrame>(py)?
170
.ldf
171
.into_inner()
172
.logical_plan,
173
) as _)
174
})
175
}),
176
schema: Arc::new(|py_f| {
177
Python::attach(|py| {
178
Ok(Box::new(py_f.extract::<Wrap<polars_core::schema::Schema>>(py)?.0) as _)
179
})
180
}),
181
},
182
to_py: polars_utils::python_convert_registry::ToPythonConvertRegistry {
183
df: Arc::new(|df| {
184
Python::attach(|py| {
185
PyDataFrame::new(df.downcast_ref::<DataFrame>().unwrap().clone())
186
.into_py_any(py)
187
})
188
}),
189
series: Arc::new(|series| {
190
Python::attach(|py| {
191
PySeries::new(series.downcast_ref::<Series>().unwrap().clone())
192
.into_py_any(py)
193
})
194
}),
195
dsl_plan: Arc::new(|dsl_plan| {
196
Python::attach(|py| {
197
PyLazyFrame::from(LazyFrame::from(
198
dsl_plan
199
.downcast_ref::<polars_plan::dsl::DslPlan>()
200
.unwrap()
201
.clone(),
202
))
203
.into_py_any(py)
204
})
205
}),
206
schema: Arc::new(|schema| {
207
Python::attach(|py| {
208
Wrap(
209
schema
210
.downcast_ref::<polars_core::schema::Schema>()
211
.unwrap()
212
.clone(),
213
)
214
.into_py_any(py)
215
})
216
}),
217
},
218
});
219
220
let object_size = size_of::<ObjectValue>();
221
let physical_dtype = ArrowDataType::FixedSizeBinary(object_size);
222
registry::register_object_builder(
223
object_builder,
224
object_converter,
225
pyobject_converter,
226
physical_dtype,
227
Arc::new(object_array_getter)
228
);
229
230
use crate::dataset::dataset_provider_funcs;
231
232
polars_plan::dsl::DATASET_PROVIDER_VTABLE.get_or_init(|| PythonDatasetProviderVTable {
233
name: dataset_provider_funcs::name,
234
schema: dataset_provider_funcs::schema,
235
to_dataset_scan: dataset_provider_funcs::to_dataset_scan,
236
});
237
238
// Register SERIES UDF.
239
python_dsl::CALL_COLUMNS_UDF_PYTHON = Some(python_function_caller_series);
240
// Register DATAFRAME UDF.
241
python_dsl::CALL_DF_UDF_PYTHON = Some(python_function_caller_df);
242
// Register warning function for `polars_warn!`.
243
polars_error::set_warning_function(warning_function);
244
245
if catch_keyboard_interrupt {
246
register_polars_keyboard_interrupt_hook();
247
}
248
249
use polars_core::datatypes::extension::UnknownExtensionTypeBehavior;
250
let behavior = match std::env::var("POLARS_UNKNOWN_EXTENSION_TYPE_BEHAVIOR").as_deref() {
251
Ok("load_as_storage") => UnknownExtensionTypeBehavior::LoadAsStorage,
252
Ok("load_as_extension") => UnknownExtensionTypeBehavior::LoadAsGeneric,
253
Ok("") | Err(_) => UnknownExtensionTypeBehavior::WarnAndLoadAsStorage,
254
_ => {
255
polars_warn!("Invalid value for 'POLARS_UNKNOWN_EXTENSION_TYPE_BEHAVIOR' environment variable. Expected one of 'load_as_storage' or 'load_as_extension'.");
256
UnknownExtensionTypeBehavior::WarnAndLoadAsStorage
257
},
258
};
259
polars_core::datatypes::extension::set_unknown_extension_type_behavior(behavior);
260
});
261
}
262
263