use std::fmt;
use std::io::{Read, Write};
use std::sync::{Arc, Mutex};
use polars_utils::arena::Node;
#[cfg(feature = "serde")]
use polars_utils::pl_serialize;
use polars_utils::unique_id::UniqueId;
use recursive::recursive;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use super::*;
/// Version of the serialized DSL format, as `(major, minor)`.
///
/// `DslPlan::serialize_versioned` writes both parts as little-endian `u16`s right after
/// `DSL_MAGIC_BYTES`. `DslPlan::deserialize_versioned` rejects input whose major version
/// differs from this one, or whose minor version is higher than this one.
pub const DSL_VERSION: (u16, u16) = (23, 0);
/// Marker written at the very start of every versioned DSL payload and checked on read.
const DSL_MAGIC_BYTES: &[u8] = b"DSL_VERSION";
/// Hash of the DSL schema this build was compiled against, read from
/// `$OUT_DIR/dsl-schema.sha256` at compile time.
///
/// It is written after the version, and incoming payloads must carry an equal hash
/// (compared ASCII case-insensitively) to be deserialized.
const DSL_SCHEMA_HASH: SchemaHash<'static> = SchemaHash::from_hash_file();
// A logical query plan as built by the DSL layer. It is lowered to IR by
// `DslPlan::to_alp` and, with the `serde` feature, serialized by
// `DslPlan::serialize_versioned`.
//
// NOTE(review): comments in this enum are deliberately plain `//` rather than `///`, so they
// cannot leak into the schema derived with the `dsl-schema` feature (whose hash is
// `DSL_SCHEMA_HASH`). Confirm that schema is unaffected before converting them to doc comments.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum DslPlan {
    // Scan configured by Python-side options; only compiled with the `python` feature.
    #[cfg(feature = "python")]
    PythonScan {
        options: crate::dsl::python_dsl::PythonOptionsDsl,
    },
    // Filter `input` by the expression `predicate`.
    Filter {
        input: Arc<DslPlan>,
        predicate: Expr,
    },
    // Cache point over `input`, identified by `id`.
    Cache {
        input: Arc<DslPlan>,
        id: UniqueId,
    },
    // Scan of external `sources`; `scan_type` presumably selects the file format — confirm.
    Scan {
        sources: ScanSources,
        unified_scan_args: Box<UnifiedScanArgs>,
        scan_type: Box<FileScanDsl>,
        // Not serialized. `Clone` copies only the `Arc`, so every clone of this plan shares
        // this slot. NOTE(review): presumably filled in when the scan is lowered to IR — confirm.
        #[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(skip))]
        cached_ir: Arc<Mutex<Option<IR>>>,
    },
    // Scan of an in-memory `df`, together with its `schema`.
    DataFrameScan {
        df: Arc<DataFrame>,
        schema: SchemaRef,
    },
    // Evaluate the expressions `expr` over `input`.
    Select {
        expr: Vec<Expr>,
        input: Arc<DslPlan>,
        options: ProjectionOptions,
    },
    // Group `input` by `keys` and evaluate `aggs`. `apply` optionally holds a
    // DataFrame -> DataFrame callback together with a schema (presumably its output schema).
    GroupBy {
        input: Arc<DslPlan>,
        keys: Vec<Expr>,
        aggs: Vec<Expr>,
        maintain_order: bool,
        options: Arc<GroupbyOptions>,
        apply: Option<(PlanCallback<DataFrame, DataFrame>, SchemaRef)>,
    },
    // Join two inputs on the key expressions `left_on` / `right_on`, plus any extra
    // `predicates`.
    Join {
        input_left: Arc<DslPlan>,
        input_right: Arc<DslPlan>,
        left_on: Vec<Expr>,
        right_on: Vec<Expr>,
        predicates: Vec<Expr>,
        options: Arc<JoinOptions>,
    },
    // Evaluate `exprs` and stack them onto `input` horizontally (presumably `with_columns`).
    HStack {
        input: Arc<DslPlan>,
        exprs: Vec<Expr>,
        options: ProjectionOptions,
    },
    // Reconcile `input` with `match_schema`, using per-column settings and a policy for
    // columns not present in `match_schema`.
    MatchToSchema {
        input: Arc<DslPlan>,
        match_schema: SchemaRef,
        per_column: Arc<[MatchToSchemaPerColumn]>,
        extra_columns: ExtraColumnsPolicy,
    },
    // Replace this node with the plan returned by `callback`, which is given a
    // `(DslPlan, Schema)` pair (presumably `input` and its resolved schema).
    PipeWithSchema {
        input: Arc<DslPlan>,
        callback: PlanCallback<(DslPlan, Schema), DslPlan>,
    },
    // Deduplicate `input` according to `options`.
    Distinct {
        input: Arc<DslPlan>,
        options: DistinctOptionsDSL,
    },
    // Sort `input` by `by_column`. `slice` is an optional `(offset, len)` pair, presumably
    // applied to the sorted output — confirm.
    Sort {
        input: Arc<DslPlan>,
        by_column: Vec<Expr>,
        slice: Option<(i64, usize)>,
        sort_options: SortMultipleOptions,
    },
    // Take `len` rows of `input` starting at `offset`.
    Slice {
        input: Arc<DslPlan>,
        offset: i64,
        len: IdxSize,
    },
    // Apply `function` to `input`.
    MapFunction {
        input: Arc<DslPlan>,
        function: DslFunction,
    },
    // Vertical concatenation of `inputs`.
    Union {
        inputs: Vec<DslPlan>,
        args: UnionArgs,
    },
    // Horizontal concatenation of `inputs`.
    HConcat {
        inputs: Vec<DslPlan>,
        options: HConcatOptions,
    },
    // `input` evaluated with extra `contexts` available.
    ExtContext {
        input: Arc<DslPlan>,
        contexts: Vec<DslPlan>,
    },
    // Write `input` to the destination described by `payload`.
    Sink {
        input: Arc<DslPlan>,
        payload: SinkType,
    },
    // Several sink plans combined into one plan.
    SinkMultiple {
        inputs: Vec<DslPlan>,
    },
    // Merge two inputs on the column `key`; only compiled with the `merge_sorted` feature.
    #[cfg(feature = "merge_sorted")]
    MergeSorted {
        input_left: Arc<DslPlan>,
        input_right: Arc<DslPlan>,
        key: PlSmallStr,
    },
    // Wraps the original `dsl` with a `version` number and an optional arena `node`.
    // NOTE(review): presumably a handle to an already-lowered plan — confirm.
    IR {
        dsl: Arc<DslPlan>,
        version: u32,
        // Not serialized.
        #[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(skip))]
        node: Option<Node>,
    },
}
/// Hand-written `Clone` for [`DslPlan`]: every field of every variant is cloned.
///
/// `Arc` fields are cloned with [`Arc::clone`], so the copy shares those sub-plans (and the
/// `Scan` variant's `cached_ir` slot) with the original. `Vec<DslPlan>` fields are cloned
/// element by element, which re-enters this function for every child plan.
///
/// NOTE(review): `#[recursive]` is presumably why this impl is written by hand rather than
/// derived — it should keep cloning of very deep plans from overflowing the stack. Confirm.
impl Clone for DslPlan {
    #[allow(clippy::clone_on_copy)]
    #[recursive]
    fn clone(&self) -> Self {
        match self {
            #[cfg(feature = "python")]
            Self::PythonScan { options } => Self::PythonScan {
                options: options.clone(),
            },
            Self::Filter { input, predicate } => Self::Filter {
                input: Arc::clone(input),
                predicate: predicate.clone(),
            },
            Self::Cache { input, id } => Self::Cache {
                input: Arc::clone(input),
                id: *id,
            },
            Self::Scan {
                sources,
                unified_scan_args,
                scan_type,
                cached_ir,
            } => Self::Scan {
                sources: sources.clone(),
                unified_scan_args: unified_scan_args.clone(),
                scan_type: scan_type.clone(),
                // Shared, not deep-copied: both plans observe the same cache slot.
                cached_ir: Arc::clone(cached_ir),
            },
            Self::DataFrameScan { df, schema } => Self::DataFrameScan {
                df: Arc::clone(df),
                schema: schema.clone(),
            },
            Self::Select {
                expr,
                input,
                options,
            } => Self::Select {
                expr: expr.clone(),
                input: Arc::clone(input),
                options: options.clone(),
            },
            Self::GroupBy {
                input,
                keys,
                aggs,
                maintain_order,
                options,
                apply,
            } => Self::GroupBy {
                input: Arc::clone(input),
                keys: keys.clone(),
                aggs: aggs.clone(),
                maintain_order: *maintain_order,
                options: Arc::clone(options),
                apply: apply.clone(),
            },
            Self::Join {
                input_left,
                input_right,
                left_on,
                right_on,
                predicates,
                options,
            } => Self::Join {
                input_left: Arc::clone(input_left),
                input_right: Arc::clone(input_right),
                left_on: left_on.clone(),
                right_on: right_on.clone(),
                predicates: predicates.clone(),
                options: Arc::clone(options),
            },
            Self::HStack {
                input,
                exprs,
                options,
            } => Self::HStack {
                input: Arc::clone(input),
                exprs: exprs.clone(),
                options: options.clone(),
            },
            Self::MatchToSchema {
                input,
                match_schema,
                per_column,
                extra_columns,
            } => Self::MatchToSchema {
                input: Arc::clone(input),
                match_schema: match_schema.clone(),
                per_column: Arc::clone(per_column),
                extra_columns: *extra_columns,
            },
            Self::PipeWithSchema { input, callback } => Self::PipeWithSchema {
                input: Arc::clone(input),
                callback: callback.clone(),
            },
            Self::Distinct { input, options } => Self::Distinct {
                input: Arc::clone(input),
                options: options.clone(),
            },
            Self::Sort {
                input,
                by_column,
                slice,
                sort_options,
            } => Self::Sort {
                input: Arc::clone(input),
                by_column: by_column.clone(),
                slice: *slice,
                sort_options: sort_options.clone(),
            },
            Self::Slice { input, offset, len } => Self::Slice {
                input: Arc::clone(input),
                offset: *offset,
                len: *len,
            },
            Self::MapFunction { input, function } => Self::MapFunction {
                input: Arc::clone(input),
                function: function.clone(),
            },
            Self::Union { inputs, args } => Self::Union {
                inputs: inputs.clone(),
                args: args.clone(),
            },
            Self::HConcat { inputs, options } => Self::HConcat {
                inputs: inputs.clone(),
                options: options.clone(),
            },
            Self::ExtContext { input, contexts } => Self::ExtContext {
                input: Arc::clone(input),
                contexts: contexts.clone(),
            },
            Self::Sink { input, payload } => Self::Sink {
                input: Arc::clone(input),
                payload: payload.clone(),
            },
            Self::SinkMultiple { inputs } => Self::SinkMultiple {
                inputs: inputs.clone(),
            },
            #[cfg(feature = "merge_sorted")]
            Self::MergeSorted {
                input_left,
                input_right,
                key,
            } => Self::MergeSorted {
                input_left: Arc::clone(input_left),
                input_right: Arc::clone(input_right),
                key: key.clone(),
            },
            Self::IR { dsl, version, node } => Self::IR {
                dsl: Arc::clone(dsl),
                version: *version,
                node: *node,
            },
        }
    }
}
impl Default for DslPlan {
fn default() -> Self {
let df = DataFrame::empty();
let schema = df.schema().clone();
DslPlan::DataFrameScan {
df: Arc::new(df),
schema,
}
}
}
/// Options that control `DslPlan::serialize_versioned`.
#[derive(Default, Clone, Copy)]
pub struct PlanSerializationContext {
    /// Value that `serialize_versioned` stores in
    /// `polars_utils::pl_serialize::USE_CLOUDPICKLE` before writing the plan.
    /// `false` by default.
    pub use_cloudpickle: bool,
}
impl DslPlan {
pub fn describe(&self) -> PolarsResult<String> {
Ok(self.clone().to_alp()?.describe())
}
pub fn describe_tree_format(&self) -> PolarsResult<String> {
Ok(self.clone().to_alp()?.describe_tree_format())
}
pub fn display(&self) -> PolarsResult<impl fmt::Display> {
struct DslPlanDisplay(IRPlan);
impl fmt::Display for DslPlanDisplay {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(&self.0.as_ref().display(), f)
}
}
Ok(DslPlanDisplay(self.clone().to_alp()?))
}
pub fn to_alp(self) -> PolarsResult<IRPlan> {
let mut lp_arena = Arena::with_capacity(16);
let mut expr_arena = Arena::with_capacity(16);
let node = to_alp(
self,
&mut expr_arena,
&mut lp_arena,
&mut OptFlags::default(),
)?;
let plan = IRPlan::new(node, lp_arena, expr_arena);
Ok(plan)
}
#[cfg(feature = "serde")]
pub fn serialize_versioned<W: Write>(
&self,
mut writer: W,
ctx: PlanSerializationContext,
) -> PolarsResult<()> {
let le_major = DSL_VERSION.0.to_le_bytes();
let le_minor = DSL_VERSION.1.to_le_bytes();
polars_utils::pl_serialize::USE_CLOUDPICKLE.set(ctx.use_cloudpickle);
writer.write_all(DSL_MAGIC_BYTES)?;
writer.write_all(&le_major)?;
writer.write_all(&le_minor)?;
writer.write_all(DSL_SCHEMA_HASH.as_bytes())?;
pl_serialize::serialize_dsl(writer, self)
}
#[cfg(feature = "serde")]
pub fn deserialize_versioned<R: Read>(mut reader: R) -> PolarsResult<Self> {
const MAGIC_LEN: usize = DSL_MAGIC_BYTES.len();
let mut version_magic = [0u8; MAGIC_LEN + 4];
reader
.read_exact(&mut version_magic)
.map_err(|e| polars_err!(ComputeError: "failed to read incoming DSL_VERSION: {e}"))?;
if &version_magic[..MAGIC_LEN] != DSL_MAGIC_BYTES {
polars_bail!(ComputeError: "dsl magic bytes not found")
}
let major = u16::from_le_bytes(version_magic[MAGIC_LEN..MAGIC_LEN + 2].try_into().unwrap());
let minor = u16::from_le_bytes(
version_magic[MAGIC_LEN + 2..MAGIC_LEN + 4]
.try_into()
.unwrap(),
);
const MAJOR: u16 = DSL_VERSION.0;
const MINOR: u16 = DSL_VERSION.1;
if polars_core::config::verbose() {
eprintln!(
"incoming DSL_VERSION: {major}.{minor}, deserializer DSL_VERSION: {MAJOR}.{MINOR}"
);
}
if major != MAJOR {
polars_bail!(ComputeError:
"deserialization failed\n\ngiven DSL_VERSION: {major}.{minor} is not compatible with this Polars version which uses DSL_VERSION: {MAJOR}.{MINOR}\n{}",
"error: can't deserialize DSL with a different major version"
);
}
if minor > MINOR {
polars_bail!(ComputeError:
"deserialization failed\n\ngiven DSL_VERSION: {major}.{minor} is not compatible with this Polars version which uses DSL_VERSION: {MAJOR}.{MINOR}\n{}",
"error: can't deserialize DSL with a higher minor version"
);
}
let mut schema_hash = [0_u8; SCHEMA_HASH_LEN];
reader.read_exact(&mut schema_hash).map_err(
|e| polars_err!(ComputeError: "failed to read incoming DSL_SCHEMA_HASH: {e}"),
)?;
let incoming_hash = SchemaHash::new(&schema_hash).ok_or_else(
|| polars_err!(ComputeError: "failed to read incoming DSL schema hash, not a valid hex string")
)?;
if polars_core::config::verbose() {
eprintln!(
"incoming DSL_SCHEMA_HASH: {incoming_hash}, deserializer DSL_SCHEMA_HASH: {DSL_SCHEMA_HASH}"
);
}
if incoming_hash != DSL_SCHEMA_HASH {
polars_bail!(ComputeError:
"deserialization failed\n\ngiven DSL_SCHEMA_HASH: {incoming_hash} is not compatible with this Polars version which uses DSL_SCHEMA_HASH: {DSL_SCHEMA_HASH}\n{}",
"error: can't deserialize DSL with incompatible schema"
);
}
pl_serialize::deserialize_dsl(reader)
.map_err(|e| polars_err!(ComputeError: "deserialization failed\n\nerror: {e}"))
}
#[cfg(feature = "dsl-schema")]
pub fn dsl_schema() -> schemars::schema::RootSchema {
use schemars::r#gen::SchemaSettings;
use schemars::schema::SchemaObject;
use schemars::visit::{Visitor, visit_schema_object};
#[derive(Clone, Copy, Debug)]
struct MyVisitor;
impl Visitor for MyVisitor {
fn visit_schema_object(&mut self, schema: &mut SchemaObject) {
if schema.metadata.is_some() {
schema.metadata().description = None;
}
visit_schema_object(self, schema);
}
}
let mut schema = SchemaSettings::default()
.with_visitor(MyVisitor)
.into_generator()
.into_root_schema_for::<DslPlan>();
schema
.schema
.extensions
.insert("hash".into(), DSL_SCHEMA_HASH.to_string().into());
schema
}
}
/// Length in bytes of a DSL schema hash: 64 ASCII hex characters, the length of a
/// hex-encoded SHA-256 digest (the embedded file is `dsl-schema.sha256`).
const SCHEMA_HASH_LEN: usize = 64;
/// A DSL schema hash borrowed for `'a`: `SCHEMA_HASH_LEN` ASCII hex digits.
///
/// In this file it is only built through `SchemaHash::new`, which validates the digits.
/// Equality ignores ASCII case.
struct SchemaHash<'a>(&'a str);
impl SchemaHash<'static> {
const fn from_hash_file() -> Self {
let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/dsl-schema.sha256"));
Self::new(bytes).expect("not a valid hex string")
}
}
impl<'a> SchemaHash<'a> {
    /// Wraps `bytes` as a schema hash if every byte is an ASCII hex digit (either case).
    ///
    /// Returns `None` as soon as a non-hex byte is found. Usable in const contexts.
    const fn new(bytes: &'a [u8; SCHEMA_HASH_LEN]) -> Option<Self> {
        let mut idx = 0;
        while idx < SCHEMA_HASH_LEN {
            if !bytes[idx].is_ascii_hexdigit() {
                return None;
            }
            idx += 1;
        }

        // Every byte is an ASCII hex digit, so the input is necessarily valid UTF-8.
        if let Ok(hash) = str::from_utf8(bytes) {
            Some(Self(hash))
        } else {
            unreachable!()
        }
    }

    /// Returns the hash's ASCII hex digits as a fixed-size byte array.
    fn as_bytes(&self) -> &'a [u8; SCHEMA_HASH_LEN] {
        // Length is guaranteed by `new`, which only accepts `SCHEMA_HASH_LEN` bytes.
        <&[u8; SCHEMA_HASH_LEN]>::try_from(self.0.as_bytes()).unwrap()
    }
}
impl PartialEq for SchemaHash<'_> {
    /// Hashes are equal when their hex digits match, ignoring ASCII letter case.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.0.as_bytes(), other.0.as_bytes());
        lhs.eq_ignore_ascii_case(rhs)
    }
}
impl fmt::Display for SchemaHash<'_> {
    /// Writes the hash digits exactly as stored, preserving their original case.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.0)
    }
}