GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-io/src/pl_async.rs
use std::error::Error;
use std::future::Future;
use std::ops::Deref;
use std::sync::LazyLock;

use polars_core::POOL;
use polars_core::config::{self, verbose};
use polars_utils::relaxed_cell::RelaxedCell;
use tokio::runtime::{Builder, Runtime};
use tokio::sync::Semaphore;

static CONCURRENCY_BUDGET: std::sync::OnceLock<(Semaphore, u32)> = std::sync::OnceLock::new();
pub(super) const MAX_BUDGET_PER_REQUEST: usize = 10;

/// Used to determine chunks when splitting large ranges, or combining small
/// ranges.
static DOWNLOAD_CHUNK_SIZE: LazyLock<usize> = LazyLock::new(|| {
    let v: usize = std::env::var("POLARS_DOWNLOAD_CHUNK_SIZE")
        .as_deref()
        .map(|x| x.parse().expect("integer"))
        .unwrap_or(64 * 1024 * 1024);

    if config::verbose() {
        eprintln!("async download_chunk_size: {v}")
    }

    v
});

pub(super) fn get_download_chunk_size() -> usize {
    *DOWNLOAD_CHUNK_SIZE
}

static UPLOAD_CHUNK_SIZE: LazyLock<usize> = LazyLock::new(|| {
    let v: usize = std::env::var("POLARS_UPLOAD_CHUNK_SIZE")
        .as_deref()
        .map(|x| x.parse().expect("integer"))
        .unwrap_or(64 * 1024 * 1024);

    if config::verbose() {
        eprintln!("async upload_chunk_size: {v}")
    }

    v
});

pub(super) fn get_upload_chunk_size() -> usize {
    *UPLOAD_CHUNK_SIZE
}
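
// Illustrative sketch (not in the upstream file): how these getters are meant
// to be consumed. A reader splitting a large remote object into byte ranges
// would size each request by `get_download_chunk_size()`; the helper below is
// hypothetical.
//
//     fn split_into_chunks(len: usize) -> Vec<std::ops::Range<usize>> {
//         let chunk = get_download_chunk_size(); // 64 MiB unless overridden
//         (0..len)
//             .step_by(chunk)
//             .map(|start| start..usize::min(start + chunk, len))
//             .collect()
//     }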

pub trait GetSize {
    fn size(&self) -> u64;
}

impl GetSize for bytes::Bytes {
    fn size(&self) -> u64 {
        self.len() as u64
    }
}

impl<T: GetSize> GetSize for Vec<T> {
    fn size(&self) -> u64 {
        self.iter().map(|v| v.size()).sum()
    }
}

impl<T: GetSize, E: Error> GetSize for Result<T, E> {
    fn size(&self) -> u64 {
        match self {
            Ok(v) => v.size(),
            Err(_) => 0,
        }
    }
}

#[cfg(feature = "cloud")]
pub(crate) struct Size(u64);

#[cfg(feature = "cloud")]
impl GetSize for Size {
    fn size(&self) -> u64 {
        self.0
    }
}
#[cfg(feature = "cloud")]
impl From<u64> for Size {
    fn from(value: u64) -> Self {
        Self(value)
    }
}
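
// Illustrative note (not in the upstream file): `GetSize` is how the tuner in
// `tune_with_concurrency_budget` learns how many bytes a request transferred.
// A future that only knows a raw byte count (e.g. bytes written by an upload)
// can report it by wrapping the count in `Size`:
//
//     let n_bytes_written: u64 = 1024;
//     tune_with_concurrency_budget(1, || async move { Size::from(n_bytes_written) }).await;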

enum Optimization {
    Step,
    Accept,
    Finished,
}

struct SemaphoreTuner {
    previous_download_speed: u64,
    last_tune: std::time::Instant,
    downloaded: RelaxedCell<u64>,
    download_time: RelaxedCell<u64>,
    opt_state: Optimization,
    increments: u32,
}

impl SemaphoreTuner {
    fn new() -> Self {
        Self {
            previous_download_speed: 0,
            last_tune: std::time::Instant::now(),
            downloaded: RelaxedCell::from(0),
            download_time: RelaxedCell::from(0),
            opt_state: Optimization::Step,
            increments: 0,
        }
    }
    fn should_tune(&self) -> bool {
        match self.opt_state {
            Optimization::Finished => false,
            _ => self.last_tune.elapsed().as_millis() > 350,
        }
    }

    fn add_stats(&self, downloaded_bytes: u64, download_time: u64) {
        self.downloaded.fetch_add(downloaded_bytes);
        self.download_time.fetch_add(download_time);
    }

    fn increment(&mut self, semaphore: &Semaphore) {
        semaphore.add_permits(1);
        self.increments += 1;
    }

    fn tune(&mut self, semaphore: &'static Semaphore) -> bool {
        let bytes_downloaded = self.downloaded.load();
        let time_elapsed = self.download_time.load();
        let download_speed = bytes_downloaded
            .checked_div(time_elapsed)
            .unwrap_or_default();

        let increased = download_speed > self.previous_download_speed;
        self.previous_download_speed = download_speed;
        match self.opt_state {
            Optimization::Step => {
                self.increment(semaphore);
                self.opt_state = Optimization::Accept
            },
            Optimization::Accept => {
                // Accept the step
                if increased {
                    // Set new step
                    self.increment(semaphore);
                    // Keep accept state to check next iteration
                }
                // Decline the step
                else {
                    self.opt_state = Optimization::Finished;
                    FINISHED_TUNING.store(true);
                    if verbose() {
                        eprintln!(
                            "concurrency tuner finished after adding {} steps",
                            self.increments
                        )
                    }
                    // Finished.
                    return true;
                }
            },
            Optimization::Finished => {},
        }
        self.last_tune = std::time::Instant::now();
        // Not finished.
        false
    }
}
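
// Illustrative summary (not in the upstream file) of the hill-climbing loop
// above: roughly every 350 ms one task divides the bytes downloaded by the
// accumulated download time and compares the result with the previous
// measurement.
//
//     Step     -> add one permit, move to Accept
//     Accept   -> speed went up:   add another permit, stay in Accept
//                 speed went down: mark tuning Finished (sets FINISHED_TUNING)
//     Finished -> no further changes; the caller in `tune_with_concurrency_budget`
//                 retires the last added permit.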
static INCR: RelaxedCell<u64> = RelaxedCell::new_u64(0);
static FINISHED_TUNING: RelaxedCell<bool> = RelaxedCell::new_bool(false);
static PERMIT_STORE: std::sync::OnceLock<tokio::sync::RwLock<SemaphoreTuner>> =
    std::sync::OnceLock::new();

fn get_semaphore() -> &'static (Semaphore, u32) {
    CONCURRENCY_BUDGET.get_or_init(|| {
        let permits = std::env::var("POLARS_CONCURRENCY_BUDGET")
            .map(|s| {
                let budget = s.parse::<usize>().expect("integer");
                FINISHED_TUNING.store(true);
                budget
            })
            .unwrap_or_else(|_| std::cmp::max(POOL.current_num_threads(), MAX_BUDGET_PER_REQUEST));
        (Semaphore::new(permits), permits as u32)
    })
}
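
// Worked example (illustrative, not in the upstream file): with a rayon POOL of
// 32 threads the initial budget is max(32, MAX_BUDGET_PER_REQUEST) = 32 permits,
// which the tuner may later grow. Setting POLARS_CONCURRENCY_BUDGET=4 instead
// pins the budget at 4 permits and disables tuning (FINISHED_TUNING is set).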

pub(crate) fn get_concurrency_limit() -> u32 {
    get_semaphore().1
}

pub async fn tune_with_concurrency_budget<F, Fut>(requested_budget: u32, callable: F) -> Fut::Output
where
    F: FnOnce() -> Fut,
    Fut: Future,
    Fut::Output: GetSize,
{
    let (semaphore, initial_budget) = get_semaphore();

    // This would never finish otherwise.
    assert!(requested_budget <= *initial_budget);

    // Keep permit around.
    // On drop it is returned to the semaphore.
    let _permit_acq = semaphore.acquire_many(requested_budget).await.unwrap();

    let now = std::time::Instant::now();
    let res = callable().await;

    if FINISHED_TUNING.load() || res.size() == 0 {
        return res;
    }

    let duration = now.elapsed().as_millis() as u64;
    let permit_store = PERMIT_STORE.get_or_init(|| tokio::sync::RwLock::new(SemaphoreTuner::new()));

    let Ok(tuner) = permit_store.try_read() else {
        return res;
    };
    // Keep track of download speed
    tuner.add_stats(res.size(), duration);

    // We only tune every n ms
    if !tuner.should_tune() {
        return res;
    }
    // Drop the read tuner before trying to acquire a writer
    drop(tuner);

    // Reduce locking by letting only 1 in 5 tasks lock the tuner
    if !INCR.fetch_add(1).is_multiple_of(5) {
        return res;
    }
    // Never lock as we will deadlock. This can run under rayon
    let Ok(mut tuner) = permit_store.try_write() else {
        return res;
    };
    let finished = tuner.tune(semaphore);
    if finished {
        drop(_permit_acq);
        // Undo the last step
        let undo = semaphore.acquire().await.unwrap();
        std::mem::forget(undo)
    }
    res
}
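
// Illustrative sketch (not in the upstream file): a download wrapped in the
// tuned budget. The future's output must implement `GetSize` (here
// `bytes::Bytes`) so the tuner can record how many bytes the request moved;
// `fetch_range` is a hypothetical request function.
//
//     async fn download_part(range: std::ops::Range<usize>) -> bytes::Bytes {
//         tune_with_concurrency_budget(1, || fetch_range(range)).await
//     }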

pub async fn with_concurrency_budget<F, Fut>(requested_budget: u32, callable: F) -> Fut::Output
where
    F: FnOnce() -> Fut,
    Fut: Future,
{
    let (semaphore, initial_budget) = get_semaphore();

    // This would never finish otherwise.
    assert!(requested_budget <= *initial_budget);

    // Keep permit around.
    // On drop it is returned to the semaphore.
    let _permit_acq = semaphore.acquire_many(requested_budget).await.unwrap();

    callable().await
}
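
// Illustrative note (not in the upstream file): this is the untuned variant for
// requests whose output has no meaningful byte size (the output does not need
// to implement `GetSize`), e.g. a metadata or listing call. It still holds
// `requested_budget` permits for the duration of the call; `client.head_object`
// below is hypothetical.
//
//     let meta = with_concurrency_budget(1, || client.head_object(path)).await;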

pub struct RuntimeManager {
    rt: Runtime,
}

impl RuntimeManager {
    fn new() -> Self {
        let n_threads = std::env::var("POLARS_ASYNC_THREAD_COUNT")
            .map(|x| x.parse::<usize>().expect("integer"))
            .unwrap_or(POOL.current_num_threads().clamp(1, 4));

        if polars_core::config::verbose() {
            eprintln!("async thread count: {n_threads}");
        }

        let rt = Builder::new_multi_thread()
            .worker_threads(n_threads)
            .enable_io()
            .enable_time()
            .build()
            .unwrap();

        Self { rt }
    }

    /// Forcibly blocks this thread to evaluate the given future. This is
    /// dangerous if called re-entrantly on an async worker thread, as the
    /// entire thread pool can end up blocked, resulting in a deadlock. If you
    /// want to prevent this, use `block_on`, which panics when called from an
    /// async thread.
    pub fn block_in_place_on<F>(&self, future: F) -> F::Output
    where
        F: Future,
    {
        tokio::task::block_in_place(|| self.rt.block_on(future))
    }

    /// Blocks this thread to evaluate the given future. Panics if the current
    /// thread is an async runtime worker thread.
    pub fn block_on<F>(&self, future: F) -> F::Output
    where
        F: Future,
    {
        self.rt.block_on(future)
    }

    /// Spawns a future onto the Tokio runtime (see [`tokio::runtime::Runtime::spawn`]).
    pub fn spawn<F>(&self, future: F) -> tokio::task::JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        self.rt.spawn(future)
    }

    /// See [`tokio::runtime::Runtime::spawn_blocking`].
    pub fn spawn_blocking<F, R>(&self, f: F) -> tokio::task::JoinHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        self.rt.spawn_blocking(f)
    }

    /// Run a task on the rayon threadpool. To avoid deadlocks, if the current
    /// thread is already a rayon thread, the task is executed directly on this
    /// thread inside tokio's `block_in_place`, which lets the runtime hand its
    /// other queued futures to another worker while this thread is busy.
    pub async fn spawn_rayon<F, O>(&self, func: F) -> O
    where
        F: FnOnce() -> O + Send + Sync + 'static,
        O: Send + Sync + 'static,
    {
        if POOL.current_thread_index().is_some() {
            // We are a rayon thread, so we can't use POOL.spawn as it would mean we spawn a task and block until
            // another rayon thread executes it - we would deadlock if all rayon threads did this.
            // Safety: The tokio runtime flavor is multi-threaded.
            tokio::task::block_in_place(func)
        } else {
            let (tx, rx) = tokio::sync::oneshot::channel();

            let func = move || {
                let out = func();
                // Don't unwrap send attempt - async task could be cancelled.
                let _ = tx.send(out);
            };

            POOL.spawn(func);

            rx.await.unwrap()
        }
    }
}

static RUNTIME: LazyLock<RuntimeManager> = LazyLock::new(RuntimeManager::new);

pub fn get_runtime() -> &'static RuntimeManager {
    RUNTIME.deref()
}
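
// Illustrative sketch (not in the upstream file): how synchronous code can
// drive an async operation through the shared runtime; `read_object` is a
// hypothetical async function.
//
//     fn read_blocking(path: &str) -> bytes::Bytes {
//         get_runtime().block_in_place_on(async { read_object(path).await })
//     }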