Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-python/src/extension.rs
7884 views
1
use std::any::Any;
2
use std::borrow::Cow;
3
use std::hash::{BuildHasher, Hash, Hasher};
4
use std::sync::Arc;
5
6
use polars::prelude::PlFixedStateQuality;
7
use polars::prelude::extension::{register_extension_type, unregister_extension_type};
8
use polars_core::datatypes::DataType;
9
use polars_core::datatypes::extension::{ExtensionTypeFactory, ExtensionTypeImpl};
10
use pyo3::prelude::*;
11
12
use crate::prelude::Wrap;
13
use crate::utils::to_py_err;
14
15
struct PyExtensionTypeFactory {
16
cls: Arc<Py<PyAny>>,
17
}
18
19
#[derive(Clone)]
20
struct PyExtensionTypeImpl {
21
name: String,
22
display: String,
23
metadata: Option<String>,
24
}
25
26
impl ExtensionTypeFactory for PyExtensionTypeFactory {
27
fn create_type_instance(
28
&self,
29
name: &str,
30
storage: &DataType,
31
metadata: Option<&str>,
32
) -> Box<dyn ExtensionTypeImpl> {
33
Python::attach(|py| {
34
let typ_obj = self
35
.cls
36
.bind(py)
37
.call_method1("ext_from_params", (name, &Wrap(storage.clone()), metadata))
38
.unwrap();
39
40
let display = typ_obj
41
.call_method0("_string_repr")
42
.unwrap()
43
.extract()
44
.unwrap();
45
let metadata = typ_obj
46
.call_method0("ext_metadata")
47
.unwrap()
48
.extract()
49
.unwrap();
50
51
Box::new(PyExtensionTypeImpl {
52
name: name.to_string(),
53
display,
54
metadata,
55
})
56
})
57
}
58
}
59
60
impl ExtensionTypeImpl for PyExtensionTypeImpl {
61
fn name(&self) -> Cow<'_, str> {
62
Cow::Borrowed(&self.name)
63
}
64
65
fn serialize_metadata(&self) -> Option<Cow<'_, str>> {
66
self.metadata.as_deref().map(Cow::Borrowed)
67
}
68
69
fn dyn_clone(&self) -> Box<dyn ExtensionTypeImpl> {
70
Box::new(self.clone())
71
}
72
73
fn dyn_eq(&self, other: &dyn ExtensionTypeImpl) -> bool {
74
let Some(other) = (other as &dyn Any).downcast_ref::<PyExtensionTypeImpl>() else {
75
return false;
76
};
77
78
self.name == other.name && self.metadata == other.metadata
79
}
80
81
fn dyn_hash(&self) -> u64 {
82
let mut hasher = PlFixedStateQuality::default().build_hasher();
83
self.name.hash(&mut hasher);
84
self.metadata.hash(&mut hasher);
85
hasher.finish()
86
}
87
88
fn dyn_display(&self) -> Cow<'_, str> {
89
Cow::Borrowed(&self.display)
90
}
91
92
fn dyn_debug(&self) -> Cow<'_, str> {
93
if let Some(md) = &self.metadata {
94
Cow::Owned(format!(
95
"PyExtensionType(name='{}', metadata='{}')",
96
self.name, md
97
))
98
} else {
99
Cow::Owned(format!("PyExtensionType(name='{}')", self.name))
100
}
101
}
102
}
103
104
#[pyfunction]
105
pub fn _register_extension_type(name: &str, cls: Option<&Bound<PyAny>>) -> PyResult<()> {
106
register_extension_type(
107
name,
108
cls.map(|c| {
109
Arc::new(PyExtensionTypeFactory {
110
cls: Arc::new(c.clone().unbind()),
111
}) as _
112
}),
113
)
114
.map_err(to_py_err)
115
}
116
117
#[pyfunction]
118
pub fn _unregister_extension_type(name: &str) -> PyResult<()> {
119
unregister_extension_type(name).map(drop).map_err(to_py_err)
120
}
121
122