Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-stream/src/nodes/rle.rs
6939 views
1
use arrow::array::builder::ShareStrategy;
2
use polars_core::frame::DataFrame;
3
use polars_core::prelude::{
4
AnyValue, DataType, Field, IDX_DTYPE, IntoColumn, NamedFrom, StructChunked,
5
};
6
use polars_core::scalar::Scalar;
7
use polars_core::series::Series;
8
use polars_core::series::builder::SeriesBuilder;
9
use polars_error::PolarsResult;
10
use polars_ops::series::{RLE_LENGTH_COLUMN_NAME, RLE_VALUE_COLUMN_NAME};
11
use polars_utils::IdxSize;
12
use polars_utils::pl_str::PlSmallStr;
13
14
use super::ComputeNode;
15
use crate::async_executor::{JoinHandle, TaskPriority, TaskScope};
16
use crate::execute::StreamingExecutionState;
17
use crate::graph::PortState;
18
use crate::morsel::{Morsel, MorselSeq, SourceToken};
19
use crate::pipe::{RecvPort, SendPort};
20
21
pub struct RleNode {
22
name: PlSmallStr,
23
dtype: DataType,
24
25
seq: MorselSeq,
26
27
// Invariant: last == None <=> last_length == 0
28
last_length: IdxSize,
29
last: Option<AnyValue<'static>>,
30
}
31
32
impl RleNode {
33
pub fn new(name: PlSmallStr, dtype: DataType) -> Self {
34
Self {
35
name,
36
dtype,
37
seq: MorselSeq::default(),
38
last_length: 0,
39
last: None,
40
}
41
}
42
}
43
44
impl ComputeNode for RleNode {
45
fn name(&self) -> &str {
46
"rle"
47
}
48
49
fn update_state(
50
&mut self,
51
recv: &mut [PortState],
52
send: &mut [PortState],
53
_state: &StreamingExecutionState,
54
) -> PolarsResult<()> {
55
assert!(recv.len() == 1 && send.len() == 1);
56
57
if send[0] == PortState::Done {
58
recv[0] = PortState::Done;
59
self.last_length = 0;
60
self.last.take();
61
} else if recv[0] == PortState::Done {
62
if self.last.is_some() {
63
send[0] = PortState::Ready;
64
} else {
65
send[0] = PortState::Done;
66
}
67
} else {
68
recv.swap_with_slice(send);
69
}
70
71
Ok(())
72
}
73
74
fn spawn<'env, 's>(
75
&'env mut self,
76
scope: &'s TaskScope<'s, 'env>,
77
recv_ports: &mut [Option<RecvPort<'_>>],
78
send_ports: &mut [Option<SendPort<'_>>],
79
_state: &'s StreamingExecutionState,
80
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
81
) {
82
assert_eq!(recv_ports.len(), 1);
83
assert_eq!(send_ports.len(), 1);
84
85
let recv = recv_ports[0].take();
86
let mut send = send_ports[0].take().unwrap().serial();
87
88
let fields = vec![
89
Field::new(PlSmallStr::from_static(RLE_LENGTH_COLUMN_NAME), IDX_DTYPE),
90
Field::new(
91
PlSmallStr::from_static(RLE_VALUE_COLUMN_NAME),
92
self.dtype.clone(),
93
),
94
];
95
let output_dtype = DataType::Struct(fields.clone());
96
97
match recv {
98
None => {
99
// This happens when we have received out last morsel and we need to return one
100
// more value.
101
let last = self.last.take().unwrap();
102
if self.last_length > 0 {
103
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
104
let column = Scalar::new(
105
output_dtype,
106
AnyValue::StructOwned(Box::new((
107
vec![AnyValue::from(self.last_length), last],
108
fields,
109
))),
110
)
111
.into_column(self.name.clone());
112
113
let df = DataFrame::new(vec![column]).unwrap();
114
_ = send
115
.send(Morsel::new(df, self.seq.successor(), SourceToken::new()))
116
.await;
117
118
self.last_length = 0;
119
Ok(())
120
}));
121
}
122
},
123
124
Some(recv) => {
125
let mut recv = recv.serial();
126
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
127
let mut idxs = Vec::new();
128
let mut lengths = Vec::new();
129
while let Ok(mut m) = recv.recv().await {
130
self.seq = m.seq();
131
if m.df().height() == 0 {
132
continue;
133
}
134
135
assert_eq!(m.df().width(), 1);
136
let column = &m.df()[0];
137
138
lengths.clear();
139
polars_ops::series::rle_lengths(column, &mut lengths)?;
140
141
let mut new_first_is_last = false;
142
if let Some(last) = &self.last {
143
let fst = Scalar::new(
144
self.dtype.clone(),
145
column.get(0).unwrap().into_static(),
146
);
147
let last = Scalar::new(self.dtype.clone(), last.clone());
148
new_first_is_last = fst == last;
149
}
150
151
// If we have a morsel that is all the same value and we already know that
152
// value. Just add it to the length and continue.
153
if lengths.len() == 1 && new_first_is_last {
154
self.last_length += lengths[0];
155
continue;
156
}
157
158
let mut values = SeriesBuilder::new(self.dtype.clone());
159
values.reserve(lengths.len());
160
161
// Create the gather indices.
162
idxs.clear();
163
idxs.reserve(lengths.len() - 1);
164
let mut idx = 0;
165
for l in &lengths[0..lengths.len() - 1] {
166
idxs.push(idx);
167
idx += *l;
168
}
169
170
// Update the lengths to match what is being gathered and with the last
171
// element.
172
if new_first_is_last || self.last.is_none() {
173
lengths[0] += self.last_length;
174
self.last_length = lengths.pop().unwrap();
175
} else {
176
let mut prev = self.last_length;
177
for l in lengths.iter_mut() {
178
std::mem::swap(l, &mut prev);
179
}
180
self.last_length = prev;
181
}
182
let old_last = self
183
.last
184
.replace(column.get(column.len() - 1).unwrap().into_static());
185
186
// If we have nothing to return, just continue.
187
if lengths.is_empty() {
188
continue;
189
}
190
191
// If the morsel starts with a new value. We need to make sure to push it
192
// into the output values.
193
if !new_first_is_last && let Some(last) = old_last {
194
values.push_any_value(last);
195
}
196
197
// Actually gather the remaining values.
198
unsafe {
199
values.gather_extend(
200
column.as_materialized_series(),
201
&idxs,
202
ShareStrategy::Always,
203
)
204
};
205
206
let lengths = Series::new(
207
PlSmallStr::from_static(RLE_LENGTH_COLUMN_NAME),
208
std::mem::take(&mut lengths),
209
);
210
let series = values.freeze(PlSmallStr::from_static(RLE_VALUE_COLUMN_NAME));
211
212
let rle_struct = StructChunked::from_series(
213
self.name.clone(),
214
lengths.len(),
215
[&lengths, &series].into_iter(),
216
)
217
.unwrap();
218
*m.df_mut() = DataFrame::new(vec![rle_struct.into_column()]).unwrap();
219
220
if send.send(m).await.is_err() {
221
break;
222
}
223
}
224
Ok(())
225
}));
226
},
227
}
228
}
229
}
230
231