Path: blob/main/crates/polars-plan/src/dsl/serializable_plan.rs
7884 views
use polars_utils::unique_id::UniqueId;1use recursive::recursive;2use serde::{Deserialize, Serialize};3use slotmap::{SecondaryMap, SlotMap, new_key_type};45use super::*;67new_key_type! {8/// A key type for identifying DataFrame nodes in a serialized DSL plan.9pub(crate) struct DataFrameKey;1011/// A key type for identifying DslPlan nodes in a serialized DSL plan.12pub(crate) struct DslPlanKey;13}1415/// A representation of DslPlan that does not contain any `Arc` pointers, and16/// instead uses indices to refer to DataFrames and other DslPlan nodes.17///18/// This data structure mirrors the `DslPlan` enum, but uses `DataFrameKey` and19/// `DslPlanKey` to refer to DataFrames and other DslPlan nodes, respectively.20/// We it like this, because serde does not support the keeping of a global21/// state during (de)serialization. Instead, we do a manual conversion to a22/// serde-compatible representation, and then let serde handle the rest.23#[derive(Debug)]24#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]25pub(crate) struct SerializableDslPlan {26pub(crate) root: DslPlanKey,27pub(crate) dataframes: SlotMap<DataFrameKey, DataFrameSerdeWrap>,28pub(crate) dsl_plans: SlotMap<DslPlanKey, SerializableDslPlanNode>,29}3031#[derive(Debug, Serialize, Deserialize)]32pub(crate) enum SerializableDslPlanNode {33#[cfg(feature = "python")]34PythonScan {35options: crate::dsl::python_dsl::PythonOptionsDsl,36},37Filter {38input: DslPlanKey,39predicate: Expr,40},41Cache {42input: DslPlanKey,43id: UniqueId,44},45Scan {46sources: ScanSources,47unified_scan_args: Box<UnifiedScanArgs>,48scan_type: Box<FileScanDsl>,49},50DataFrameScan {51df: DataFrameKey,52schema: SchemaRef,53},54Select {55expr: Vec<Expr>,56input: DslPlanKey,57options: ProjectionOptions,58},59GroupBy {60input: DslPlanKey,61keys: Vec<Expr>,62aggs: Vec<Expr>,63predicates: Vec<Expr>,64maintain_order: bool,65options: Arc<GroupbyOptions>,66apply: Option<(PlanCallback<DataFrame, DataFrame>, SchemaRef)>,67},68Join {69input_left: DslPlanKey,70input_right: DslPlanKey,71left_on: Vec<Expr>,72right_on: Vec<Expr>,73predicates: Vec<Expr>,74options: Arc<JoinOptions>,75},76HStack {77input: DslPlanKey,78exprs: Vec<Expr>,79options: ProjectionOptions,80},81MatchToSchema {82input: DslPlanKey,83match_schema: SchemaRef,84per_column: Arc<[MatchToSchemaPerColumn]>,85extra_columns: ExtraColumnsPolicy,86},87PipeWithSchema {88input: Vec<DslPlanKey>,89callback: PlanCallback<(Vec<DslPlan>, Vec<SchemaRef>), DslPlan>,90},91#[cfg(feature = "pivot")]92Pivot {93input: DslPlanKey,94on: Selector,95on_columns: DataFrameKey,96index: Selector,97values: Selector,98agg: Expr,99maintain_order: bool,100separator: PlSmallStr,101},102Distinct {103input: DslPlanKey,104options: DistinctOptionsDSL,105},106Sort {107input: DslPlanKey,108by_column: Vec<Expr>,109slice: Option<(i64, usize)>,110sort_options: SortMultipleOptions,111},112Slice {113input: DslPlanKey,114offset: i64,115len: IdxSize,116},117MapFunction {118input: DslPlanKey,119function: DslFunction,120},121Union {122inputs: Vec<SerializableDslPlanNode>,123args: UnionArgs,124},125HConcat {126inputs: Vec<SerializableDslPlanNode>,127options: HConcatOptions,128},129ExtContext {130input: DslPlanKey,131contexts: Vec<SerializableDslPlanNode>,132},133Sink {134input: DslPlanKey,135payload: SinkType,136},137SinkMultiple {138inputs: Vec<SerializableDslPlanNode>,139},140#[cfg(feature = "merge_sorted")]141MergeSorted {142input_left: DslPlanKey,143input_right: DslPlanKey,144key: PlSmallStr,145},146IR {147dsl: DslPlanKey,148version: u32,149},150}151152#[derive(Debug, Default)]153struct SerializeArenas {154dataframes: SlotMap<DataFrameKey, DataFrameSerdeWrap>,155dataframes_keys_table: PlIndexMap<*const DataFrame, DataFrameKey>,156dsl_plans: SlotMap<DslPlanKey, SerializableDslPlanNode>,157dsl_plans_keys_table: PlIndexMap<*const DslPlan, DslPlanKey>,158}159160impl From<&DslPlan> for SerializableDslPlan {161fn from(plan: &DslPlan) -> Self {162let mut arenas = SerializeArenas::default();163let root_dsl_plan = convert_dsl_plan_to_serializable_plan(plan, &mut arenas);164165let root_key = arenas.dsl_plans.insert(root_dsl_plan);166SerializableDslPlan {167root: root_key,168dataframes: arenas.dataframes,169dsl_plans: arenas.dsl_plans,170}171}172}173174#[recursive]175fn convert_dsl_plan_to_serializable_plan(176plan: &DslPlan,177arenas: &mut SerializeArenas,178) -> SerializableDslPlanNode {179use {DslPlan as DP, SerializableDslPlanNode as SP};180181match plan {182#[cfg(feature = "python")]183DP::PythonScan { options } => SP::PythonScan {184options: options.clone(),185},186DP::Filter { input, predicate } => SP::Filter {187input: dsl_plan_key(input, arenas),188predicate: predicate.clone(),189},190DP::Cache { input, id } => SP::Cache {191input: dsl_plan_key(input, arenas),192id: *id,193},194DP::Scan {195sources,196unified_scan_args,197scan_type,198cached_ir: _,199} => SP::Scan {200sources: sources.clone(),201unified_scan_args: unified_scan_args.clone(),202scan_type: scan_type.clone(),203},204DP::DataFrameScan { df, schema } => SP::DataFrameScan {205df: dataframe_key(df, arenas),206schema: schema.clone(),207},208DP::Select {209expr,210input,211options,212} => SP::Select {213expr: expr.clone(),214input: dsl_plan_key(input, arenas),215options: *options,216},217DP::GroupBy {218input,219keys,220aggs,221predicates,222maintain_order,223options,224apply,225} => SP::GroupBy {226input: dsl_plan_key(input, arenas),227keys: keys.clone(),228aggs: aggs.clone(),229predicates: predicates.clone(),230maintain_order: *maintain_order,231options: options.clone(),232apply: apply.clone(),233},234DP::Join {235input_left,236input_right,237left_on,238right_on,239predicates,240options,241} => SP::Join {242input_left: dsl_plan_key(input_left, arenas),243input_right: dsl_plan_key(input_right, arenas),244left_on: left_on.clone(),245right_on: right_on.clone(),246predicates: predicates.clone(),247options: options.clone(),248},249DP::HStack {250input,251exprs,252options,253} => SP::HStack {254input: dsl_plan_key(input, arenas),255exprs: exprs.clone(),256options: *options,257},258DP::MatchToSchema {259input,260match_schema,261per_column,262extra_columns,263} => SP::MatchToSchema {264input: dsl_plan_key(input, arenas),265match_schema: match_schema.clone(),266per_column: per_column.clone(),267extra_columns: *extra_columns,268},269DP::PipeWithSchema { input, callback } => SP::PipeWithSchema {270input: input271.iter()272.map(|plan| dsl_plan_key_from_ref(plan, arenas))273.collect(),274callback: callback.clone(),275},276#[cfg(feature = "pivot")]277DP::Pivot {278input,279on,280on_columns,281index,282values,283agg,284maintain_order,285separator,286} => SP::Pivot {287input: dsl_plan_key(input, arenas),288on: on.clone(),289on_columns: dataframe_key(on_columns, arenas),290index: index.clone(),291values: values.clone(),292agg: agg.clone(),293maintain_order: *maintain_order,294separator: separator.clone(),295},296DP::Distinct { input, options } => SP::Distinct {297input: dsl_plan_key(input, arenas),298options: options.clone(),299},300DP::Sort {301input,302by_column,303slice,304sort_options,305} => SP::Sort {306input: dsl_plan_key(input, arenas),307by_column: by_column.clone(),308slice: *slice,309sort_options: sort_options.clone(),310},311DP::Slice { input, offset, len } => SP::Slice {312input: dsl_plan_key(input, arenas),313offset: *offset,314len: *len,315},316DP::MapFunction { input, function } => SP::MapFunction {317input: dsl_plan_key(input, arenas),318function: function.clone(),319},320DP::Union { inputs, args } => SP::Union {321inputs: inputs322.iter()323.map(|p| convert_dsl_plan_to_serializable_plan(p, arenas))324.collect(),325args: *args,326},327DP::HConcat { inputs, options } => SP::HConcat {328inputs: inputs329.iter()330.map(|p| convert_dsl_plan_to_serializable_plan(p, arenas))331.collect(),332options: *options,333},334DP::ExtContext { input, contexts } => SP::ExtContext {335input: dsl_plan_key(input, arenas),336contexts: contexts337.iter()338.map(|p| convert_dsl_plan_to_serializable_plan(p, arenas))339.collect(),340},341DP::Sink { input, payload } => SP::Sink {342input: dsl_plan_key(input, arenas),343payload: payload.clone(),344},345DP::SinkMultiple { inputs } => SP::SinkMultiple {346inputs: inputs347.iter()348.map(|p| convert_dsl_plan_to_serializable_plan(p, arenas))349.collect(),350},351#[cfg(feature = "merge_sorted")]352DP::MergeSorted {353input_left,354input_right,355key,356} => SP::MergeSorted {357input_left: dsl_plan_key(input_left, arenas),358input_right: dsl_plan_key(input_right, arenas),359key: key.clone(),360},361DP::IR {362dsl,363version: _,364node: _,365} => convert_dsl_plan_to_serializable_plan(dsl.as_ref(), arenas),366}367}368369fn dataframe_key(df: &Arc<DataFrame>, arenas: &mut SerializeArenas) -> DataFrameKey {370let ptr = Arc::as_ptr(df);371if let Some(key) = arenas.dataframes_keys_table.get(&ptr) {372*key373} else {374let key = arenas.dataframes.insert(DataFrameSerdeWrap(df.clone()));375arenas.dataframes_keys_table.insert(ptr, key);376key377}378}379380fn dsl_plan_key_from_ref(plan: &DslPlan, arenas: &mut SerializeArenas) -> DslPlanKey {381let ptr = plan as *const _;382if let Some(key) = arenas.dsl_plans_keys_table.get(&ptr) {383*key384} else {385let ser_plan = convert_dsl_plan_to_serializable_plan(plan, arenas);386let key = arenas.dsl_plans.insert(ser_plan);387arenas.dsl_plans_keys_table.insert(ptr, key);388key389}390}391392fn dsl_plan_key(plan: &Arc<DslPlan>, arenas: &mut SerializeArenas) -> DslPlanKey {393let ref_plan = Arc::as_ref(plan);394dsl_plan_key_from_ref(ref_plan, arenas)395}396397#[derive(Debug, Default)]398struct DeserializeArenas {399dataframes: SecondaryMap<DataFrameKey, DataFrameSerdeWrap>,400dsl_plans: SecondaryMap<DslPlanKey, Arc<DslPlan>>,401}402403impl TryFrom<&SerializableDslPlan> for DslPlan {404type Error = PolarsError;405406fn try_from(ser_dsl_plan: &SerializableDslPlan) -> Result<Self, Self::Error> {407let mut arenas = DeserializeArenas::default();408let root = ser_dsl_plan409.dsl_plans410.get(ser_dsl_plan.root)411.ok_or(polars_err!(ComputeError: "Could not find root DslPlan in serialized plan"))?;412try_convert_serializable_plan_to_dsl_plan(root, ser_dsl_plan, &mut arenas)413}414}415416#[recursive]417fn try_convert_serializable_plan_to_dsl_plan(418node: &SerializableDslPlanNode,419ser_dsl_plan: &SerializableDslPlan,420arenas: &mut DeserializeArenas,421) -> Result<DslPlan, PolarsError> {422use {DslPlan as DP, SerializableDslPlanNode as SP};423424match node {425#[cfg(feature = "python")]426SP::PythonScan { options } => Ok(DP::PythonScan {427options: options.clone(),428}),429SP::Filter { input, predicate } => Ok(DP::Filter {430input: get_dsl_plan(*input, ser_dsl_plan, arenas)?,431predicate: predicate.clone(),432}),433SP::Cache { input, id } => Ok(DP::Cache {434input: get_dsl_plan(*input, ser_dsl_plan, arenas)?,435id: *id,436}),437SP::Scan {438sources,439unified_scan_args,440scan_type,441} => Ok(DP::Scan {442sources: sources.clone(),443unified_scan_args: unified_scan_args.clone(),444scan_type: scan_type.clone(),445cached_ir: Default::default(),446}),447SP::DataFrameScan { df, schema } => Ok(DP::DataFrameScan {448df: get_dataframe(*df, ser_dsl_plan, arenas)?,449schema: schema.clone(),450}),451SP::Select {452expr,453input,454options,455} => Ok(DP::Select {456expr: expr.clone(),457input: get_dsl_plan(*input, ser_dsl_plan, arenas)?,458options: *options,459}),460SP::GroupBy {461input,462keys,463aggs,464predicates,465maintain_order,466options,467apply,468} => Ok(DP::GroupBy {469input: get_dsl_plan(*input, ser_dsl_plan, arenas)?,470keys: keys.clone(),471aggs: aggs.clone(),472predicates: predicates.clone(),473maintain_order: *maintain_order,474options: options.clone(),475apply: apply.clone(),476}),477SP::Join {478input_left,479input_right,480left_on,481right_on,482predicates,483options,484} => Ok(DP::Join {485input_left: get_dsl_plan(*input_left, ser_dsl_plan, arenas)?,486input_right: get_dsl_plan(*input_right, ser_dsl_plan, arenas)?,487left_on: left_on.clone(),488right_on: right_on.clone(),489predicates: predicates.clone(),490options: options.clone(),491}),492SP::HStack {493input,494exprs,495options,496} => Ok(DP::HStack {497input: get_dsl_plan(*input, ser_dsl_plan, arenas)?,498exprs: exprs.clone(),499options: *options,500}),501SP::MatchToSchema {502input,503match_schema,504per_column,505extra_columns,506} => Ok(DP::MatchToSchema {507input: get_dsl_plan(*input, ser_dsl_plan, arenas)?,508match_schema: match_schema.clone(),509per_column: per_column.clone(),510extra_columns: *extra_columns,511}),512SP::PipeWithSchema { input, callback } => Ok(DP::PipeWithSchema {513input: Arc::from(514input515.iter()516.map(|key| get_dsl_plan(*key, ser_dsl_plan, arenas).map(Arc::unwrap_or_clone))517.collect::<PolarsResult<Vec<_>>>()?,518),519callback: callback.clone(),520}),521#[cfg(feature = "pivot")]522SP::Pivot {523input,524on,525on_columns,526index,527values,528agg,529maintain_order,530separator,531} => Ok(DP::Pivot {532input: get_dsl_plan(*input, ser_dsl_plan, arenas)?,533on: on.clone(),534on_columns: get_dataframe(*on_columns, ser_dsl_plan, arenas)?,535index: index.clone(),536values: values.clone(),537agg: agg.clone(),538maintain_order: *maintain_order,539separator: separator.clone(),540}),541SP::Distinct { input, options } => Ok(DP::Distinct {542input: get_dsl_plan(*input, ser_dsl_plan, arenas)?,543options: options.clone(),544}),545SP::Sort {546input,547by_column,548slice,549sort_options,550} => Ok(DP::Sort {551input: get_dsl_plan(*input, ser_dsl_plan, arenas)?,552by_column: by_column.clone(),553slice: *slice,554sort_options: sort_options.clone(),555}),556SP::Slice { input, offset, len } => Ok(DP::Slice {557input: get_dsl_plan(*input, ser_dsl_plan, arenas)?,558offset: *offset,559len: *len,560}),561SP::MapFunction { input, function } => Ok(DP::MapFunction {562input: get_dsl_plan(*input, ser_dsl_plan, arenas)?,563function: function.clone(),564}),565SP::Union { inputs, args } => Ok(DP::Union {566inputs: inputs567.iter()568.map(|node| try_convert_serializable_plan_to_dsl_plan(node, ser_dsl_plan, arenas))569.collect::<Result<Vec<_>, _>>()?,570args: *args,571}),572SP::HConcat { inputs, options } => Ok(DP::HConcat {573inputs: inputs574.iter()575.map(|node| try_convert_serializable_plan_to_dsl_plan(node, ser_dsl_plan, arenas))576.collect::<Result<Vec<_>, _>>()?,577options: *options,578}),579SP::ExtContext { input, contexts } => Ok(DP::ExtContext {580input: get_dsl_plan(*input, ser_dsl_plan, arenas)?,581contexts: contexts582.iter()583.map(|node| try_convert_serializable_plan_to_dsl_plan(node, ser_dsl_plan, arenas))584.collect::<Result<Vec<_>, _>>()?,585}),586SP::Sink { input, payload } => Ok(DP::Sink {587input: get_dsl_plan(*input, ser_dsl_plan, arenas)?,588payload: payload.clone(),589}),590SP::SinkMultiple { inputs } => Ok(DP::SinkMultiple {591inputs: inputs592.iter()593.map(|node| try_convert_serializable_plan_to_dsl_plan(node, ser_dsl_plan, arenas))594.collect::<Result<Vec<_>, _>>()?,595}),596#[cfg(feature = "merge_sorted")]597SP::MergeSorted {598input_left,599input_right,600key,601} => Ok(DP::MergeSorted {602input_left: get_dsl_plan(*input_left, ser_dsl_plan, arenas)?,603input_right: get_dsl_plan(*input_right, ser_dsl_plan, arenas)?,604key: key.clone(),605}),606SP::IR {607dsl: dsl_key,608version: _,609} => get_dsl_plan(*dsl_key, ser_dsl_plan, arenas).map(Arc::unwrap_or_clone),610}611}612613fn get_dataframe(614key: DataFrameKey,615ser_dsl_plan: &SerializableDslPlan,616arenas: &mut DeserializeArenas,617) -> Result<Arc<DataFrame>, PolarsError> {618if let Some(df) = arenas.dataframes.get(key) {619Ok(df.0.clone())620} else {621let df = ser_dsl_plan.dataframes.get(key).ok_or(polars_err!(622ComputeError: "Could not find DataFrame at index {:?} in serialized plan", key623))?;624arenas.dataframes.insert(key, df.clone());625Ok(df.0.clone())626}627}628629fn get_dsl_plan(630key: DslPlanKey,631ser_dsl_plan: &SerializableDslPlan,632arenas: &mut DeserializeArenas,633) -> Result<Arc<DslPlan>, PolarsError> {634if let Some(dsl_plan) = arenas.dsl_plans.get(key) {635Ok(dsl_plan.clone())636} else {637let node = ser_dsl_plan.dsl_plans.get(key).ok_or(polars_err!(638ComputeError: "Could not find DslPlan node at index {:?} in serialized plan", key639))?;640let dsl_plan = try_convert_serializable_plan_to_dsl_plan(node, ser_dsl_plan, arenas)?;641let arc_dsl_plan = Arc::new(dsl_plan);642arenas.dsl_plans.insert(key, arc_dsl_plan.clone());643Ok(arc_dsl_plan)644}645}646647/// Serialization wrapper that splits large serialized byte values into chunks.648#[derive(Debug, Clone)]649pub(crate) struct DataFrameSerdeWrap(Arc<DataFrame>);650651#[cfg(feature = "serde")]652mod _serde_impl {653use std::sync::Arc;654655use polars_core::frame::DataFrame;656use polars_utils::chunked_bytes_cursor::FixedSizeChunkedBytesCursor;657use serde::de::Error;658use serde::{Deserialize, Serialize};659660use super::DataFrameSerdeWrap;661662fn max_byte_slice_len() -> usize {663std::env::var("POLARS_SERIALIZE_LAZYFRAME_MAX_BYTE_SLICE_LEN")664.as_deref()665.map_or(666usize::try_from(u32::MAX).unwrap(), // Limit for rmp_serde667|x| x.parse().unwrap(),668)669}670671impl Serialize for DataFrameSerdeWrap {672fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>673where674S: serde::Serializer,675{676use serde::ser::Error;677678let mut bytes: Vec<u8> = vec![];679self.0680.as_ref()681.clone()682.serialize_into_writer(&mut bytes)683.map_err(S::Error::custom)?;684685serializer.collect_seq(bytes.chunks(max_byte_slice_len()))686}687}688689impl<'de> Deserialize<'de> for DataFrameSerdeWrap {690fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>691where692D: serde::Deserializer<'de>,693{694let bytes: Vec<Vec<u8>> = Vec::deserialize(deserializer)?;695696let result = match bytes.as_slice() {697[v] => DataFrame::deserialize_from_reader(&mut std::io::Cursor::new(v.as_slice())),698_ => DataFrame::deserialize_from_reader(699&mut FixedSizeChunkedBytesCursor::try_new(bytes.as_slice()).unwrap(),700),701};702703result704.map(|x| DataFrameSerdeWrap(Arc::new(x)))705.map_err(D::Error::custom)706}707}708}709710#[cfg(test)]711mod tests {712use super::*;713714#[test]715fn test_dsl_plan_serialization() {716let name = || "a".into();717let df = Arc::new(718DataFrame::new(vec![Column::new(name(), Series::new(name(), &[1, 2, 3]))]).unwrap(),719);720let dfscan = Arc::new(DslPlan::DataFrameScan {721df: df.clone(),722schema: df.schema().clone(),723});724let join_options = JoinOptions {725allow_parallel: true,726force_parallel: false,727..Default::default()728};729let lf = DslPlan::Join {730input_left: dfscan.clone(),731input_right: dfscan,732left_on: vec![Expr::Column(name())],733right_on: vec![Expr::Column(name())],734predicates: Default::default(),735options: Arc::new(join_options),736};737let mut buffer: Vec<u8> = Vec::new();738lf.serialize_versioned(&mut buffer, Default::default())739.unwrap();740let mut reader: &[u8] = &buffer;741let deserialized = DslPlan::deserialize_versioned(&mut reader).unwrap();742assert_eq!(format!("{lf:?}"), format!("{deserialized:?}"));743}744}745746747