Path: blob/main/crates/polars-io/src/cloud/polars_object_store.rs
6939 views
use std::ops::Range;

use bytes::Bytes;
use futures::{StreamExt, TryStreamExt};
use hashbrown::hash_map::RawEntryMut;
use object_store::path::Path;
use object_store::{ObjectMeta, ObjectStore};
use polars_core::prelude::{InitHashMaps, PlHashMap};
use polars_error::{PolarsError, PolarsResult};
use polars_utils::mmap::MemSlice;
use tokio::io::{AsyncSeekExt, AsyncWriteExt};

use crate::pl_async::{
    self, MAX_BUDGET_PER_REQUEST, get_concurrency_limit, get_download_chunk_size,
    tune_with_concurrency_budget, with_concurrency_budget,
};

mod inner {
    use std::future::Future;
    use std::sync::Arc;

    use object_store::ObjectStore;
    use polars_core::config;
    use polars_error::PolarsResult;
    use polars_utils::relaxed_cell::RelaxedCell;

    use crate::cloud::PolarsObjectStoreBuilder;

    /// Shared state: the current (possibly re-built) store plus the builder
    /// needed to re-construct it on credential/connection failure.
    #[derive(Debug)]
    struct Inner {
        store: tokio::sync::Mutex<Arc<dyn ObjectStore>>,
        builder: PolarsObjectStoreBuilder,
    }

    /// Polars wrapper around [`ObjectStore`] functionality. This struct is cheaply cloneable.
    #[derive(Clone, Debug)]
    pub struct PolarsObjectStore {
        inner: Arc<Inner>,
        /// Avoid contending the Mutex `lock()` until the first re-build.
        initial_store: std::sync::Arc<dyn ObjectStore>,
        /// Used for interior mutability. Doesn't need to be shared with other threads so it's not
        /// inside `Arc<>`.
        rebuilt: RelaxedCell<bool>,
    }

    impl PolarsObjectStore {
        pub(crate) fn new_from_inner(
            store: Arc<dyn ObjectStore>,
            builder: PolarsObjectStoreBuilder,
        ) -> Self {
            let initial_store = store.clone();
            Self {
                inner: Arc::new(Inner {
                    store: tokio::sync::Mutex::new(store),
                    builder,
                }),
                initial_store,
                rebuilt: RelaxedCell::from(false),
            }
        }

        /// Gets the underlying [`ObjectStore`] implementation.
        pub async fn to_dyn_object_store(&self) -> Arc<dyn ObjectStore> {
            // Fast path: until the first rebuild, the lock-free cached copy is current.
            if !self.rebuilt.load() {
                self.initial_store.clone()
            } else {
                self.inner.store.lock().await.clone()
            }
        }

        /// Re-build the inner store from the builder, unless another task
        /// already did so (detected via pointer identity with `from_version`).
        pub async fn rebuild_inner(
            &self,
            from_version: &Arc<dyn ObjectStore>,
        ) -> PolarsResult<Arc<dyn ObjectStore>> {
            let mut current_store = self.inner.store.lock().await;

            // If this does not eq, then `inner` was already re-built by another thread.
            if Arc::ptr_eq(&*current_store, from_version) {
                *current_store =
                    self.inner
                        .builder
                        .clone()
                        .build_impl(true)
                        .await
                        .map_err(|e| {
                            e.wrap_msg(|e| format!("attempt to rebuild object store failed: {e}"))
                        })?;
            }

            self.rebuilt.store(true);

            Ok((*current_store).clone())
        }

        /// Run `func` against the current store; on error, re-build the store
        /// once and retry. The second error (if any) is what the caller sees.
        pub async fn try_exec_rebuild_on_err<Fn, Fut, O>(&self, mut func: Fn) -> PolarsResult<O>
        where
            Fn: FnMut(&Arc<dyn ObjectStore>) -> Fut,
            Fut: Future<Output = PolarsResult<O>>,
        {
            let store = self.to_dyn_object_store().await;

            let out = func(&store).await;

            let orig_err = match out {
                Ok(v) => return Ok(v),
                Err(e) => e,
            };

            if config::verbose() {
                eprintln!(
                    "[PolarsObjectStore]: got error: {}, will attempt re-build",
                    &orig_err
                );
            }

            let store = self
                .rebuild_inner(&store)
                .await
                .map_err(|e| e.wrap_msg(|e| format!("{e}; original error: {orig_err}")))?;

            func(&store).await.map_err(|e| {
                if self.inner.builder.is_azure()
                    && std::env::var("POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY").as_deref()
                        != Ok("1")
                {
                    // Note: This error is intended for Python audiences. The logic for retrieving
                    // these keys exist only on the Python side.
                    e.wrap_msg(|e| {
                        format!(
                            "{e}; note: if you are using Python, consider setting \
POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY=1 if you would like polars to try to retrieve \
and use the storage account keys from Azure CLI to authenticate"
                        )
                    })
                } else {
                    e
                }
            })
        }
    }
}

pub use inner::PolarsObjectStore;

pub type ObjectStorePath = object_store::path::Path;

impl PolarsObjectStore {
    /// Returns a buffered stream that downloads concurrently up to the concurrency limit.
    fn get_buffered_ranges_stream<'a, T: Iterator<Item = Range<usize>>>(
        store: &'a dyn ObjectStore,
        path: &'a Path,
        ranges: T,
    ) -> impl StreamExt<Item = PolarsResult<Bytes>>
    + TryStreamExt<Ok = Bytes, Error = PolarsError, Item = PolarsResult<Bytes>>
    + use<'a, T> {
        futures::stream::iter(ranges.map(move |range| async move {
            // Empty ranges are answered locally without issuing a request.
            if range.is_empty() {
                return Ok(Bytes::new());
            }

            let out = store
                .get_range(path, range.start as u64..range.end as u64)
                .await?;
            Ok(out)
        }))
        // Add a limit locally as this gets run inside a single `tune_with_concurrency_budget`.
        .buffered(get_concurrency_limit() as usize)
    }

    /// Fetch a single byte range, transparently splitting it into multiple
    /// concurrent requests when it is large (see [`split_range`]).
    pub async fn get_range(&self, path: &Path, range: Range<usize>) -> PolarsResult<Bytes> {
        if range.is_empty() {
            return Ok(Bytes::new());
        }

        self.try_exec_rebuild_on_err(move |store| {
            // Clone captured state for the retry closure (it may run twice).
            let range = range.clone();
            let st = store.clone();

            async move {
                let store = st;
                let parts = split_range(range.clone());

                if parts.len() == 1 {
                    let out = tune_with_concurrency_budget(1, move || async move {
                        store
                            .get_range(path, range.start as u64..range.end as u64)
                            .await
                    })
                    .await?;

                    Ok(out)
                } else {
                    let parts = tune_with_concurrency_budget(
                        parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                        || {
                            Self::get_buffered_ranges_stream(&store, path, parts)
                                .try_collect::<Vec<Bytes>>()
                        },
                    )
                    .await?;

                    // Re-assemble the split parts into one contiguous buffer.
                    let mut combined = Vec::with_capacity(range.len());

                    for part in parts {
                        combined.extend_from_slice(&part)
                    }

                    assert_eq!(combined.len(), range.len());

                    PolarsResult::Ok(Bytes::from(combined))
                }
            }
        })
        .await
    }

    /// Fetch byte ranges into a HashMap keyed by the range start. This will mutably sort the
    /// `ranges` slice for coalescing.
    ///
    /// # Panics
    /// Panics if the same range start is used by more than 1 range.
    pub async fn get_ranges_sort(
        &self,
        path: &Path,
        ranges: &mut [Range<usize>],
    ) -> PolarsResult<PlHashMap<usize, MemSlice>> {
        if ranges.is_empty() {
            return Ok(Default::default());
        }

        ranges.sort_unstable_by_key(|x| x.start);

        let ranges_len = ranges.len();
        // `merge_ranges` coalesces nearby/overlapping ranges into fewer requests;
        // `merged_ends[i]` is the exclusive index into `ranges` covered so far,
        // or 0 when the fetched piece is a split part to be re-combined first.
        let (merged_ranges, merged_ends): (Vec<_>, Vec<_>) = merge_ranges(ranges).unzip();

        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            async {
                let store = st;
                let mut out = PlHashMap::with_capacity(ranges_len);

                let mut stream =
                    Self::get_buffered_ranges_stream(&store, path, merged_ranges.iter().cloned());

                tune_with_concurrency_budget(
                    merged_ranges.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                    || async {
                        let mut len = 0;
                        let mut current_offset = 0;
                        let mut ends_iter = merged_ends.iter();

                        let mut splitted_parts = vec![];

                        while let Some(bytes) = stream.try_next().await? {
                            len += bytes.len();
                            let end = *ends_iter.next().unwrap();

                            // `end == 0` marks a split part of a larger merged range;
                            // accumulate until the final part arrives.
                            if end == 0 {
                                splitted_parts.push(bytes);
                                continue;
                            }

                            // Bounding range of all the original ranges served by this fetch.
                            let full_range = ranges[current_offset..end]
                                .iter()
                                .cloned()
                                .reduce(|l, r| l.start.min(r.start)..l.end.max(r.end))
                                .unwrap();

                            let bytes = if splitted_parts.is_empty() {
                                bytes
                            } else {
                                // Re-combine the buffered split parts with the final one.
                                let mut out = Vec::with_capacity(full_range.len());

                                for x in splitted_parts.drain(..) {
                                    out.extend_from_slice(&x);
                                }

                                out.extend_from_slice(&bytes);
                                Bytes::from(out)
                            };

                            assert_eq!(bytes.len(), full_range.len());

                            let bytes = MemSlice::from_bytes(bytes);

                            for range in &ranges[current_offset..end] {
                                let mem_slice = bytes.slice(
                                    range.start - full_range.start..range.end - full_range.start,
                                );

                                // Ranges sharing a start key: keep the longest slice.
                                match out.raw_entry_mut().from_key(&range.start) {
                                    RawEntryMut::Vacant(slot) => {
                                        slot.insert(range.start, mem_slice);
                                    },
                                    RawEntryMut::Occupied(mut slot) => {
                                        if slot.get_mut().len() < mem_slice.len() {
                                            *slot.get_mut() = mem_slice;
                                        }
                                    },
                                }
                            }

                            current_offset = end;
                        }

                        assert!(splitted_parts.is_empty());

                        PolarsResult::Ok(pl_async::Size::from(len as u64))
                    },
                )
                .await?;

                Ok(out)
            }
        })
        .await
    }

    /// Download an object to `file`, splitting into concurrent ranged requests
    /// when the size is known (via `head`) and large enough to warrant it.
    pub async fn download(&self, path: &Path, file: &mut tokio::fs::File) -> PolarsResult<()> {
        let opt_size = self.head(path).await.ok().map(|x| x.size);

        let initial_pos = file.stream_position().await?;

        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            // Workaround for "can't move captured variable".
            // SAFETY-NOTE(review): transmute_copy duplicates the `&mut File`
            // borrow so it can be captured by the retry closure; relies on the
            // closure invocations never running concurrently — confirm upstream.
            let file: &mut tokio::fs::File = unsafe { std::mem::transmute_copy(&file) };

            async {
                file.set_len(initial_pos).await?; // Reset if this function was called again.

                let store = st;
                // Only take the split path when it actually yields >1 part.
                let parts = opt_size
                    .map(|x| split_range(0..x as usize))
                    .filter(|x| x.len() > 1);

                if let Some(parts) = parts {
                    tune_with_concurrency_budget(
                        parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                        || async {
                            let mut stream = Self::get_buffered_ranges_stream(&store, path, parts);
                            let mut len = 0;
                            while let Some(bytes) = stream.try_next().await? {
                                len += bytes.len();
                                file.write_all(&bytes).await?;
                            }

                            assert_eq!(len, opt_size.unwrap() as usize);

                            PolarsResult::Ok(pl_async::Size::from(len as u64))
                        },
                    )
                    .await?
                } else {
                    // Unknown size (or single part): stream the whole object.
                    tune_with_concurrency_budget(1, || async {
                        let mut stream = store.get(path).await?.into_stream();

                        let mut len = 0;
                        while let Some(bytes) = stream.try_next().await? {
                            len += bytes.len();
                            file.write_all(&bytes).await?;
                        }

                        PolarsResult::Ok(pl_async::Size::from(len as u64))
                    })
                    .await?
                };

                // Dropping is delayed for tokio async files so we need to explicitly
                // flush here (https://github.com/tokio-rs/tokio/issues/2307#issuecomment-596336451).
                file.sync_all().await.map_err(PolarsError::from)?;

                Ok(())
            }
        })
        .await
    }

    /// Fetch the metadata of the parquet file, do not memoize it.
    pub async fn head(&self, path: &Path) -> PolarsResult<ObjectMeta> {
        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            async {
                with_concurrency_budget(1, || async {
                    let store = st;
                    let head_result = store.head(path).await;

                    if head_result.is_err() {
                        // Pre-signed URLs forbid the HEAD method, but we can still retrieve the header
                        // information with a 1-byte ranged GET (bytes 0-0) request.
                        let get_range_0_0_result = store
                            .get_opts(
                                path,
                                object_store::GetOptions {
                                    range: Some((0..1).into()),
                                    ..Default::default()
                                },
                            )
                            .await;

                        if let Ok(v) = get_range_0_0_result {
                            return Ok(v.meta);
                        }
                    }

                    // Fall through to the original HEAD error if the GET also failed.
                    let out = head_result?;

                    Ok(out)
                })
                .await
            }
        })
        .await
    }
}

/// Splits a single range into multiple smaller ranges, which can be downloaded concurrently for
/// much higher throughput.
fn split_range(range: Range<usize>) -> impl ExactSizeIterator<Item = Range<usize>> {
    let chunk_size = get_download_chunk_size();

    // Calculate n_parts such that we are as close as possible to the `chunk_size`.
    // The two candidates are ceil- and floor-division of the length by chunk_size;
    // whichever gives a per-part size nearer to `chunk_size` wins.
    let n_parts = [
        (range.len().div_ceil(chunk_size)).max(1),
        (range.len() / chunk_size).max(1),
    ]
    .into_iter()
    .min_by_key(|x| (range.len() / *x).abs_diff(chunk_size))
    .unwrap();

    let chunk_size = (range.len() / n_parts).max(1);

    assert_eq!(n_parts, (range.len() / chunk_size).max(1));
    let bytes_rem = range.len() % chunk_size;

    (0..n_parts).map(move |part_no| {
        let (start, end) = if part_no == 0 {
            // Download remainder length in the first chunk since it starts downloading first.
            let end = range.start + chunk_size + bytes_rem;
            let end = if end > range.end { range.end } else { end };
            (range.start, end)
        } else {
            let start = bytes_rem + range.start + part_no * chunk_size;
            (start, start + chunk_size)
        };

        start..end
    })
}

/// Note: For optimal performance, `ranges` should be sorted. More generally,
/// ranges placed next to each other should also be close in range value.
///
/// # Returns
/// `[(range1, end1), (range2, end2)]`, where:
/// * `range1` contains bytes for the ranges from `ranges[0..end1]`
/// * `range2` contains bytes for the ranges from `ranges[end1..end2]`
/// * etc..
///
/// Note that if an end value is 0, it means the range is a splitted part and should be combined.
fn merge_ranges(ranges: &[Range<usize>]) -> impl Iterator<Item = (Range<usize>, usize)> + '_ {
    let chunk_size = get_download_chunk_size();

    let mut current_merged_range = ranges.first().map_or(0..0, Clone::clone);
    // Number of fetched bytes excluding excess.
    let mut current_n_bytes = current_merged_range.len();

    (0..ranges.len())
        .filter_map(move |current_idx| {
            let current_idx = 1 + current_idx;

            if current_idx == ranges.len() {
                // No more items - flush current state.
                Some((current_merged_range.clone(), current_idx))
            } else {
                let range = ranges[current_idx].clone();

                let new_merged = current_merged_range.start.min(range.start)
                    ..current_merged_range.end.max(range.end);

                // E.g.:
                // |--------|
                //  oo           range1
                //       oo      range2
                //    ^^^        distance = 3, is_overlapping = false
                // E.g.:
                // |--------|
                //  ooooo        range1
                //     ooooo     range2
                //     ^^        distance = 2, is_overlapping = true
                let (distance, is_overlapping) = {
                    let l = current_merged_range.end.min(range.end);
                    let r = current_merged_range.start.max(range.start);

                    (r.abs_diff(l), r < l)
                };

                let should_merge = is_overlapping || {
                    // Only merge if it does not move the merged length further
                    // away from the ideal `chunk_size`...
                    let leq_current_len_dist_to_chunk_size = new_merged.len().abs_diff(chunk_size)
                        <= current_merged_range.len().abs_diff(chunk_size);
                    // ...and the gap of wasted bytes stays within 12.5% of the
                    // useful bytes, clamped to [1 MiB, 8 MiB].
                    let gap_tolerance =
                        (current_n_bytes.max(range.len()) / 8).clamp(1024 * 1024, 8 * 1024 * 1024);

                    leq_current_len_dist_to_chunk_size && distance <= gap_tolerance
                };

                if should_merge {
                    // Merge to existing range
                    current_merged_range = new_merged;
                    current_n_bytes += if is_overlapping {
                        range.len() - distance
                    } else {
                        range.len()
                    };
                    None
                } else {
                    let out = (current_merged_range.clone(), current_idx);
                    current_merged_range = range;
                    current_n_bytes = current_merged_range.len();
                    Some(out)
                }
            }
        })
        .flat_map(|x| {
            // Split large individual ranges within the list of ranges.
            let (range, end) = x;
            let split = split_range(range);
            let len = split.len();

            // Only the final split part carries the real `end`; earlier parts
            // get the 0 sentinel so the consumer knows to buffer them.
            split
                .enumerate()
                .map(move |(i, range)| (range, if 1 + i == len { end } else { 0 }))
        })
}

#[cfg(test)]
mod tests {

    #[test]
    fn test_split_range() {
        use super::{get_download_chunk_size, split_range};

        let chunk_size = get_download_chunk_size();

        assert_eq!(chunk_size, 64 * 1024 * 1024);

        #[allow(clippy::single_range_in_vec_init)]
        {
            // Round-trip empty ranges.
            assert_eq!(split_range(0..0).collect::<Vec<_>>(), [0..0]);
            assert_eq!(split_range(3..3).collect::<Vec<_>>(), [3..3]);
        }

        // Threshold to start splitting to 2 ranges
        //
        // n - chunk_size == chunk_size - n / 2
        // n + n / 2 == 2 * chunk_size
        // 3 * n == 4 * chunk_size
        // n = 4 * chunk_size / 3
        let n = 4 * chunk_size / 3;

        #[allow(clippy::single_range_in_vec_init)]
        {
            assert_eq!(split_range(0..n).collect::<Vec<_>>(), [0..89478485]);
        }

        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..44739243, 44739243..89478486]
        );

        // Threshold to start splitting to 3 ranges
        //
        // n / 2 - chunk_size == chunk_size - n / 3
        // n / 2 + n / 3 == 2 * chunk_size
        // 5 * n == 12 * chunk_size
        // n == 12 * chunk_size / 5
        let n = 12 * chunk_size / 5;

        assert_eq!(
            split_range(0..n).collect::<Vec<_>>(),
            [0..80530637, 80530637..161061273]
        );

        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..53687092, 53687092..107374183, 107374183..161061274]
        );
    }

    #[test]
    fn test_merge_ranges() {
        use super::{get_download_chunk_size, merge_ranges};

        let chunk_size = get_download_chunk_size();

        assert_eq!(chunk_size, 64 * 1024 * 1024);

        // Round-trip empty slice
        assert_eq!(merge_ranges(&[]).collect::<Vec<_>>(), []);

        // We have 1 tiny request followed by 1 huge request. They are combined as it reduces the
        // `abs_diff()` to the `chunk_size`, but afterwards they are split to 2 evenly sized
        // requests.
        assert_eq!(
            merge_ranges(&[0..1, 1..127 * 1024 * 1024]).collect::<Vec<_>>(),
            [(0..66584576, 0), (66584576..133169152, 2)]
        );

        // <= 1MiB gap, merge
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 1..1024 * 1024 + 2]).collect::<Vec<_>>(),
            [(0..1048578, 2)]
        );

        // > 1MiB gap, do not merge
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 2..1024 * 1024 + 3]).collect::<Vec<_>>(),
            [(0..1, 1), (1048578..1048579, 2)]
        );

        // <= 12.5% gap, merge
        assert_eq!(
            merge_ranges(&[0..8, 10..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

        // <= 12.5% gap relative to RHS, merge
        assert_eq!(
            merge_ranges(&[0..1, 3..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

        // Overlapping range, merge
        assert_eq!(
            merge_ranges(&[0..80 * 1024 * 1024, 10 * 1024 * 1024..70 * 1024 * 1024])
                .collect::<Vec<_>>(),
            [(0..80 * 1024 * 1024, 2)]
        );
    }
}