Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-python/src/on_startup.rs
8396 views
1
#![allow(unsafe_op_in_unsafe_fn)]
2
use std::any::Any;
3
use std::sync::OnceLock;
4
5
use arrow::array::Array;
6
use polars::chunked_array::object::ObjectArray;
7
use polars::prelude::file_provider::FileProviderReturn;
8
use polars::prelude::*;
9
use polars_core::chunked_array::object::builder::ObjectChunkedBuilder;
10
use polars_core::chunked_array::object::registry::AnonymousObjectBuilder;
11
use polars_core::chunked_array::object::{registry, set_polars_allow_extension};
12
use polars_error::PolarsWarning;
13
use polars_error::signals::register_polars_keyboard_interrupt_hook;
14
use polars_ffi::version_0::SeriesExport;
15
use polars_plan::plans::python_df_to_rust;
16
use polars_utils::python_convert_registry::{FromPythonConvertRegistry, PythonConvertRegistry};
17
use pyo3::prelude::*;
18
use pyo3::{IntoPyObjectExt, intern};
19
20
use crate::Wrap;
21
use crate::dataframe::PyDataFrame;
22
use crate::lazyframe::PyLazyFrame;
23
use crate::map::lazy::call_lambda_with_series;
24
use crate::prelude::ObjectValue;
25
use crate::py_modules::{pl_df, pl_utils, polars, polars_rs};
26
use crate::series::PySeries;
27
28
fn python_function_caller_series(
29
s: &[Column],
30
output_dtype: Option<DataType>,
31
lambda: &Py<PyAny>,
32
) -> PolarsResult<Column> {
33
Python::attach(|py| call_lambda_with_series(py, s, output_dtype, lambda))
34
}
35
36
fn python_function_caller_df(df: DataFrame, lambda: &Py<PyAny>) -> PolarsResult<DataFrame> {
37
Python::attach(|py| {
38
let pypolars = polars(py).bind(py);
39
40
// create a PySeries struct/object for Python
41
let pydf = PyDataFrame::new(df);
42
// Wrap this PySeries object in the python side Series wrapper
43
let mut python_df_wrapper = pypolars
44
.getattr("wrap_df")
45
.unwrap()
46
.call1((pydf.clone(),))
47
.unwrap();
48
49
if !python_df_wrapper
50
.getattr("_df")
51
.unwrap()
52
.is_instance(polars_rs(py).getattr(py, "PyDataFrame").unwrap().bind(py))
53
.unwrap()
54
{
55
let pldf = pl_df(py).bind(py);
56
let width = pydf.width();
57
// Don't resize the Vec to avoid calling SeriesExport's Drop impl
58
// The import takes ownership and is responsible for dropping
59
let mut columns: Vec<SeriesExport> = Vec::with_capacity(width);
60
unsafe {
61
pydf._export_columns(columns.as_mut_ptr() as usize);
62
}
63
// Wrap this PyDataFrame object in the python side DataFrame wrapper
64
python_df_wrapper = pldf
65
.getattr("_import_columns")
66
.unwrap()
67
.call1((columns.as_mut_ptr() as usize, width))
68
.unwrap();
69
}
70
// call the lambda and get a python side df wrapper
71
let result_df_wrapper = lambda.call1(py, (python_df_wrapper,))?;
72
73
// unpack the wrapper in a PyDataFrame
74
let py_pydf = result_df_wrapper.getattr(py, "_df").map_err(|_| {
75
let pytype = result_df_wrapper.bind(py).get_type();
76
PolarsError::ComputeError(
77
format!("Expected 'LazyFrame.map' to return a 'DataFrame', got a '{pytype}'",)
78
.into(),
79
)
80
})?;
81
// Downcast to Rust
82
match py_pydf.extract::<PyDataFrame>(py) {
83
Ok(pydf) => Ok(pydf.df.into_inner()),
84
Err(_) => python_df_to_rust(py, result_df_wrapper.into_bound(py)),
85
}
86
})
87
}
88
89
fn warning_function(msg: &str, warning: PolarsWarning) {
90
Python::attach(|py| {
91
let warn_fn = pl_utils(py)
92
.bind(py)
93
.getattr(intern!(py, "_polars_warn"))
94
.unwrap();
95
96
if let Err(e) = warn_fn.call1((msg, Wrap(warning).into_pyobject(py).unwrap())) {
97
eprintln!("{e}")
98
}
99
});
100
}
101
102
/// Ensures the initialization closure in `register_startup_deps` runs at most once.
static POLARS_REGISTRY_INIT_LOCK: OnceLock<()> = OnceLock::new();
103
104
/// # Safety
105
/// Caller must ensure that no other threads read the objects set by this registration.
106
pub unsafe fn register_startup_deps(catch_keyboard_interrupt: bool) {
107
// TODO: should we throw an error if we try to initialize while already initialized?
108
POLARS_REGISTRY_INIT_LOCK.get_or_init(|| {
109
set_polars_allow_extension(true);
110
111
// Stack frames can get really large in debug mode.
112
#[cfg(debug_assertions)]
113
{
114
recursive::set_minimum_stack_size(1024 * 1024);
115
recursive::set_stack_allocation_size(1024 * 1024 * 16);
116
}
117
118
#[cfg(feature = "backtrace_filter")]
119
{
120
use std::path::Path;
121
use color_backtrace::{BacktracePrinter, default_output_stream, default_is_dependency_frame, Frame, ColorScheme};
122
use color_backtrace::termcolor::{ColorSpec, Color};
123
124
let polars_base_path = || {
125
let on_startup = Path::new(file!()).canonicalize().ok()?;
126
let src = on_startup.parent()?;
127
let polars_python = src.parent()?;
128
let crates = polars_python.parent()?;
129
let root = crates.parent()?;
130
Some(root.to_path_buf())
131
};
132
133
let mut btp = BacktracePrinter::default();
134
if let Some(bp) = polars_base_path() {
135
btp = btp.dependency_predicate(Box::new(move |frame: &Frame| -> bool {
136
if let Some(file) = frame.filename.as_ref().and_then(|f| f.canonicalize().ok()) {
137
!file.starts_with(&bp)
138
} else {
139
default_is_dependency_frame(frame)
140
}
141
}));
142
}
143
144
let mut color_scheme = ColorScheme::classic();
145
color_scheme.dependency_code = ColorSpec::new();
146
color_scheme.dependency_code.set_dimmed(true);
147
color_scheme.dependency_code = color_scheme.dependency_code_hash.clone();
148
color_scheme.crate_code = ColorSpec::new();
149
color_scheme.crate_code.set_fg(Some(Color::Blue));
150
color_scheme.crate_code_hash = color_scheme.crate_code.clone();
151
152
btp
153
.color_scheme(color_scheme)
154
.install(default_output_stream());
155
}
156
157
// Register object type builder.
158
let object_builder = Box::new(|name: PlSmallStr, capacity: usize| {
159
Box::new(ObjectChunkedBuilder::<ObjectValue>::new(name, capacity))
160
as Box<dyn AnonymousObjectBuilder>
161
});
162
163
let object_converter = Arc::new(|av: AnyValue| {
164
let object = Python::attach(|py| ObjectValue {
165
inner: Wrap(av).into_py_any(py).unwrap(),
166
});
167
Box::new(object) as Box<dyn Any>
168
});
169
let pyobject_converter = Arc::new(|av: AnyValue| {
170
let object = Python::attach(|py| Wrap(av).into_py_any(py).unwrap());
171
Box::new(object) as Box<dyn Any>
172
});
173
fn object_array_getter(arr: &dyn Array, idx: usize) -> Option<AnyValue<'_>> {
174
let arr = arr.as_any().downcast_ref::<ObjectArray<ObjectValue>>().unwrap();
175
arr.get(idx).map(|v| AnyValue::Object(v))
176
}
177
178
polars_utils::python_convert_registry::register_converters(PythonConvertRegistry {
179
from_py: FromPythonConvertRegistry {
180
file_provider_result: Arc::new(|py_f| {
181
Python::attach(|py| {
182
Ok(Box::new(py_f.extract::<Wrap<FileProviderReturn>>(py)?.0) as _)
183
})
184
}),
185
series: Arc::new(|py_f| {
186
Python::attach(|py| {
187
Ok(Box::new(py_f.extract::<PySeries>(py)?.series.into_inner()) as _)
188
})
189
}),
190
df: Arc::new(|py_f| {
191
Python::attach(|py| {
192
Ok(Box::new(py_f.extract::<PyDataFrame>(py)?.df.into_inner()) as _)
193
})
194
}),
195
dsl_plan: Arc::new(|py_f| {
196
Python::attach(|py| {
197
Ok(Box::new(
198
py_f.extract::<PyLazyFrame>(py)?
199
.ldf
200
.into_inner()
201
.logical_plan,
202
) as _)
203
})
204
}),
205
schema: Arc::new(|py_f| {
206
Python::attach(|py| {
207
Ok(Box::new(py_f.extract::<Wrap<polars_core::schema::Schema>>(py)?.0) as _)
208
})
209
}),
210
},
211
to_py: polars_utils::python_convert_registry::ToPythonConvertRegistry {
212
df: Arc::new(|df| {
213
Python::attach(|py| {
214
PyDataFrame::new(df.downcast_ref::<DataFrame>().unwrap().clone())
215
.into_py_any(py)
216
})
217
}),
218
series: Arc::new(|series| {
219
Python::attach(|py| {
220
PySeries::new(series.downcast_ref::<Series>().unwrap().clone())
221
.into_py_any(py)
222
})
223
}),
224
dsl_plan: Arc::new(|dsl_plan| {
225
Python::attach(|py| {
226
PyLazyFrame::from(LazyFrame::from(
227
dsl_plan
228
.downcast_ref::<polars_plan::dsl::DslPlan>()
229
.unwrap()
230
.clone(),
231
))
232
.into_py_any(py)
233
})
234
}),
235
schema: Arc::new(|schema| {
236
Python::attach(|py| {
237
Wrap(
238
schema
239
.downcast_ref::<polars_core::schema::Schema>()
240
.unwrap()
241
.clone(),
242
)
243
.into_py_any(py)
244
})
245
}),
246
},
247
});
248
249
let object_size = size_of::<ObjectValue>();
250
let physical_dtype = ArrowDataType::FixedSizeBinary(object_size);
251
registry::register_object_builder(
252
object_builder,
253
object_converter,
254
pyobject_converter,
255
physical_dtype,
256
Arc::new(object_array_getter)
257
);
258
259
use crate::dataset::dataset_provider_funcs;
260
261
polars_plan::dsl::DATASET_PROVIDER_VTABLE.get_or_init(|| PythonDatasetProviderVTable {
262
name: dataset_provider_funcs::name,
263
schema: dataset_provider_funcs::schema,
264
to_dataset_scan: dataset_provider_funcs::to_dataset_scan,
265
});
266
267
// Register SERIES UDF.
268
python_dsl::CALL_COLUMNS_UDF_PYTHON = Some(python_function_caller_series);
269
// Register DATAFRAME UDF.
270
python_dsl::CALL_DF_UDF_PYTHON = Some(python_function_caller_df);
271
// Register warning function for `polars_warn!`.
272
polars_error::set_warning_function(warning_function);
273
274
if catch_keyboard_interrupt {
275
register_polars_keyboard_interrupt_hook();
276
}
277
278
use polars_core::datatypes::extension::UnknownExtensionTypeBehavior;
279
let behavior = match std::env::var("POLARS_UNKNOWN_EXTENSION_TYPE_BEHAVIOR").as_deref() {
280
Ok("load_as_storage") => UnknownExtensionTypeBehavior::LoadAsStorage,
281
Ok("load_as_extension") => UnknownExtensionTypeBehavior::LoadAsGeneric,
282
Ok("") | Err(_) => UnknownExtensionTypeBehavior::WarnAndLoadAsStorage,
283
_ => {
284
polars_warn!("Invalid value for 'POLARS_UNKNOWN_EXTENSION_TYPE_BEHAVIOR' environment variable. Expected one of 'load_as_storage' or 'load_as_extension'.");
285
UnknownExtensionTypeBehavior::WarnAndLoadAsStorage
286
},
287
};
288
polars_core::datatypes::extension::set_unknown_extension_type_behavior(behavior);
289
});
290
}
291
292