Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-io/src/pl_async.rs
8424 views
1
use std::error::Error;
2
use std::future::Future;
3
use std::ops::Deref;
4
use std::sync::LazyLock;
5
6
use polars_buffer::Buffer;
7
use polars_core::POOL;
8
use polars_core::config::{self, verbose};
9
use polars_utils::relaxed_cell::RelaxedCell;
10
use tokio::runtime::{Builder, Runtime};
11
use tokio::sync::Semaphore;
12
13
/// Process-wide concurrency budget: the semaphore that hands out permits,
/// paired with the number of permits it was created with.
///
/// Initialised on first use by `get_semaphore`.
static CONCURRENCY_BUDGET: std::sync::OnceLock<(Semaphore, u32)> = std::sync::OnceLock::new();
14
/// Within this file this is the floor of the default concurrency budget: when
/// `POLARS_CONCURRENCY_BUDGET` is not set, `get_semaphore` never creates fewer
/// permits than this.
///
/// NOTE(review): the name suggests it is also the most permits a single
/// request may ask for; the callers live outside this file — confirm there.
pub(super) const MAX_BUDGET_PER_REQUEST: usize = 10;
15
16
/// Used to determine chunks when splitting large ranges, or combining small
17
/// ranges.
18
static DOWNLOAD_CHUNK_SIZE: LazyLock<usize> = LazyLock::new(|| {
19
let v: usize = std::env::var("POLARS_DOWNLOAD_CHUNK_SIZE")
20
.as_deref()
21
.map(|x| x.parse().expect("integer"))
22
.unwrap_or(64 * 1024 * 1024);
23
24
if config::verbose() {
25
eprintln!("async download_chunk_size: {v}")
26
}
27
28
v
29
});
30
31
pub(super) fn get_download_chunk_size() -> usize {
32
*DOWNLOAD_CHUNK_SIZE
33
}
34
35
/// Implemented by values that can report how large they are.
///
/// `tune_with_concurrency_budget` uses the reported size as the number of
/// bytes downloaded when it estimates download speed.
pub trait GetSize {
    /// Returns the size of `self`.
    fn size(&self) -> u64;
}
38
39
/// A byte buffer reports its length as its size.
impl GetSize for Buffer<u8> {
    fn size(&self) -> u64 {
        let n_bytes = self.len();
        n_bytes as u64
    }
}
44
45
/// A vector's size is the sum of the sizes of its elements (zero when empty).
impl<T: GetSize> GetSize for Vec<T> {
    fn size(&self) -> u64 {
        self.iter().map(T::size).sum()
    }
}
50
51
/// A successful result has the size of its payload; an error has size zero.
impl<T: GetSize, E: Error> GetSize for Result<T, E> {
    fn size(&self) -> u64 {
        self.as_ref().map_or(0, T::size)
    }
}
59
60
/// Newtype over a raw `u64` size, so that a bare number can satisfy a
/// [`GetSize`] bound. Only compiled with the `cloud` feature.
#[cfg(feature = "cloud")]
pub(crate) struct Size(u64);
62
63
/// `Size` reports the number it wraps.
#[cfg(feature = "cloud")]
impl GetSize for Size {
    fn size(&self) -> u64 {
        let Self(value) = self;
        *value
    }
}
69
/// Wraps a raw `u64` in a `Size`.
#[cfg(feature = "cloud")]
impl From<u64> for Size {
    fn from(value: u64) -> Self {
        Size(value)
    }
}
75
76
/// State of the concurrency tuner (see `SemaphoreTuner::tune`).
enum Optimization {
    /// Take a step: add one permit, then move to `Accept`.
    Step,
    /// Judge the previous step: if the measured speed went up, add another
    /// permit and stay here; otherwise move to `Finished`.
    Accept,
    /// Tuning is over; no further permits are added.
    Finished,
}
81
82
struct SemaphoreTuner {
83
previous_download_speed: u64,
84
last_tune: std::time::Instant,
85
downloaded: RelaxedCell<u64>,
86
download_time: RelaxedCell<u64>,
87
opt_state: Optimization,
88
increments: u32,
89
}
90
91
impl SemaphoreTuner {
92
fn new() -> Self {
93
Self {
94
previous_download_speed: 0,
95
last_tune: std::time::Instant::now(),
96
downloaded: RelaxedCell::from(0),
97
download_time: RelaxedCell::from(0),
98
opt_state: Optimization::Step,
99
increments: 0,
100
}
101
}
102
fn should_tune(&self) -> bool {
103
match self.opt_state {
104
Optimization::Finished => false,
105
_ => self.last_tune.elapsed().as_millis() > 350,
106
}
107
}
108
109
fn add_stats(&self, downloaded_bytes: u64, download_time: u64) {
110
self.downloaded.fetch_add(downloaded_bytes);
111
self.download_time.fetch_add(download_time);
112
}
113
114
fn increment(&mut self, semaphore: &Semaphore) {
115
semaphore.add_permits(1);
116
self.increments += 1;
117
}
118
119
fn tune(&mut self, semaphore: &'static Semaphore) -> bool {
120
let bytes_downloaded = self.downloaded.load();
121
let time_elapsed = self.download_time.load();
122
let download_speed = bytes_downloaded
123
.checked_div(time_elapsed)
124
.unwrap_or_default();
125
126
let increased = download_speed > self.previous_download_speed;
127
self.previous_download_speed = download_speed;
128
match self.opt_state {
129
Optimization::Step => {
130
self.increment(semaphore);
131
self.opt_state = Optimization::Accept
132
},
133
Optimization::Accept => {
134
// Accept the step
135
if increased {
136
// Set new step
137
self.increment(semaphore);
138
// Keep accept state to check next iteration
139
}
140
// Decline the step
141
else {
142
self.opt_state = Optimization::Finished;
143
FINISHED_TUNING.store(true);
144
if verbose() {
145
eprintln!(
146
"concurrency tuner finished after adding {} steps",
147
self.increments
148
)
149
}
150
// Finished.
151
return true;
152
}
153
},
154
Optimization::Finished => {},
155
}
156
self.last_tune = std::time::Instant::now();
157
// Not finished.
158
false
159
}
160
}
161
/// Call counter used by `tune_with_concurrency_budget` to thin out attempts
/// to lock the tuner for writing: a call only proceeds when the value it gets
/// back from `fetch_add(1)` is a multiple of five.
static INCR: RelaxedCell<u64> = RelaxedCell::new_u64(0);
162
/// Set to `true` once the concurrency budget is no longer tuned: either
/// `SemaphoreTuner::tune` declined a step, or the budget was fixed through
/// `POLARS_CONCURRENCY_BUDGET` (see `get_semaphore`). While set,
/// `tune_with_concurrency_budget` skips all bookkeeping.
static FINISHED_TUNING: RelaxedCell<bool> = RelaxedCell::new_bool(false);
163
static PERMIT_STORE: std::sync::OnceLock<tokio::sync::RwLock<SemaphoreTuner>> =
164
std::sync::OnceLock::new();
165
166
fn get_semaphore() -> &'static (Semaphore, u32) {
167
CONCURRENCY_BUDGET.get_or_init(|| {
168
let permits = std::env::var("POLARS_CONCURRENCY_BUDGET")
169
.map(|s| {
170
let budget = s.parse::<usize>().expect("integer");
171
FINISHED_TUNING.store(true);
172
budget
173
})
174
.unwrap_or_else(|_| std::cmp::max(POOL.current_num_threads(), MAX_BUDGET_PER_REQUEST));
175
(Semaphore::new(permits), permits as u32)
176
})
177
}
178
179
pub(crate) fn get_concurrency_limit() -> u32 {
180
get_semaphore().1
181
}
182
183
pub async fn tune_with_concurrency_budget<F, Fut>(requested_budget: u32, callable: F) -> Fut::Output
184
where
185
F: FnOnce() -> Fut,
186
Fut: Future,
187
Fut::Output: GetSize,
188
{
189
let (semaphore, initial_budget) = get_semaphore();
190
191
// This would never finish otherwise.
192
assert!(requested_budget <= *initial_budget);
193
194
// Keep permit around.
195
// On drop it is returned to the semaphore.
196
let _permit_acq = semaphore.acquire_many(requested_budget).await.unwrap();
197
198
let now = std::time::Instant::now();
199
let res = callable().await;
200
201
if FINISHED_TUNING.load() || res.size() == 0 {
202
return res;
203
}
204
205
let duration = now.elapsed().as_millis() as u64;
206
let permit_store = PERMIT_STORE.get_or_init(|| tokio::sync::RwLock::new(SemaphoreTuner::new()));
207
208
let Ok(tuner) = permit_store.try_read() else {
209
return res;
210
};
211
// Keep track of download speed
212
tuner.add_stats(res.size(), duration);
213
214
// We only tune every n ms
215
if !tuner.should_tune() {
216
return res;
217
}
218
// Drop the read tuner before trying to acquire a writer
219
drop(tuner);
220
221
// Reduce locking by letting only 1 in 5 tasks lock the tuner
222
if !INCR.fetch_add(1).is_multiple_of(5) {
223
return res;
224
}
225
// Never lock as we will deadlock. This can run under rayon
226
let Ok(mut tuner) = permit_store.try_write() else {
227
return res;
228
};
229
let finished = tuner.tune(semaphore);
230
if finished {
231
drop(_permit_acq);
232
// Undo the last step
233
let undo = semaphore.acquire().await.unwrap();
234
std::mem::forget(undo)
235
}
236
res
237
}
238
239
pub async fn with_concurrency_budget<F, Fut>(requested_budget: u32, callable: F) -> Fut::Output
240
where
241
F: FnOnce() -> Fut,
242
Fut: Future,
243
{
244
let (semaphore, initial_budget) = get_semaphore();
245
246
// This would never finish otherwise.
247
assert!(requested_budget <= *initial_budget);
248
249
// Keep permit around.
250
// On drop it is returned to the semaphore.
251
let _permit_acq = semaphore.acquire_many(requested_budget).await.unwrap();
252
253
callable().await
254
}
255
256
pub struct RuntimeManager {
257
rt: Runtime,
258
}
259
260
impl RuntimeManager {
261
fn new() -> Self {
262
let n_threads = std::env::var("POLARS_ASYNC_THREAD_COUNT")
263
.map(|x| x.parse::<usize>().expect("integer"))
264
.unwrap_or(usize::min(POOL.current_num_threads(), 32));
265
266
let max_blocking = std::env::var("POLARS_MAX_BLOCKING_THREAD_COUNT")
267
.map(|x| x.parse::<usize>().expect("integer"))
268
.unwrap_or(512);
269
270
if polars_core::config::verbose() {
271
eprintln!("async thread count: {n_threads}");
272
eprintln!("blocking thread count: {max_blocking}");
273
}
274
275
let rt = Builder::new_multi_thread()
276
.worker_threads(n_threads)
277
.max_blocking_threads(max_blocking)
278
.enable_io()
279
.enable_time()
280
.build()
281
.unwrap();
282
283
Self { rt }
284
}
285
286
/// Forcibly blocks this thread to evaluate the given future. This can be
287
/// dangerous and lead to deadlocks if called re-entrantly on an async
288
/// worker thread as the entire thread pool can end up blocking, leading to
289
/// a deadlock. If you want to prevent this use block_on, which will panic
290
/// if called from an async thread.
291
pub fn block_in_place_on<F>(&self, future: F) -> F::Output
292
where
293
F: Future,
294
{
295
tokio::task::block_in_place(|| self.rt.block_on(future))
296
}
297
298
/// Blocks this thread to evaluate the given future. Panics if the current
299
/// thread is an async runtime worker thread.
300
pub fn block_on<F>(&self, future: F) -> F::Output
301
where
302
F: Future,
303
{
304
self.rt.block_on(future)
305
}
306
307
/// Spawns a future onto the Tokio runtime (see [`tokio::runtime::Runtime::spawn`]).
308
pub fn spawn<F>(&self, future: F) -> tokio::task::JoinHandle<F::Output>
309
where
310
F: Future + Send + 'static,
311
F::Output: Send + 'static,
312
{
313
self.rt.spawn(future)
314
}
315
316
// See [`tokio::runtime::Runtime::spawn_blocking`].
317
pub fn spawn_blocking<F, R>(&self, f: F) -> tokio::task::JoinHandle<R>
318
where
319
F: FnOnce() -> R + Send + 'static,
320
R: Send + 'static,
321
{
322
self.rt.spawn_blocking(f)
323
}
324
325
/// Run a task on the rayon threadpool. To avoid deadlocks, if the current thread is already a
326
/// rayon thread, the task is executed on the current thread after tokio's `block_in_place` is
327
/// used to spawn another thread to poll futures.
328
pub async fn spawn_rayon<F, O>(&self, func: F) -> O
329
where
330
F: FnOnce() -> O + Send + Sync + 'static,
331
O: Send + Sync + 'static,
332
{
333
if POOL.current_thread_index().is_some() {
334
// We are a rayon thread, so we can't use POOL.spawn as it would mean we spawn a task and block until
335
// another rayon thread executes it - we would deadlock if all rayon threads did this.
336
// Safety: The tokio runtime flavor is multi-threaded.
337
tokio::task::block_in_place(func)
338
} else {
339
let (tx, rx) = tokio::sync::oneshot::channel();
340
341
let func = move || {
342
let out = func();
343
// Don't unwrap send attempt - async task could be cancelled.
344
let _ = tx.send(out);
345
};
346
347
POOL.spawn(func);
348
349
rx.await.unwrap()
350
}
351
}
352
}
353
354
/// The process-wide [`RuntimeManager`], built lazily on first access (see
/// `get_runtime`).
static RUNTIME: LazyLock<RuntimeManager> = LazyLock::new(RuntimeManager::new);
355
356
pub fn get_runtime() -> &'static RuntimeManager {
357
RUNTIME.deref()
358
}
359
360