Path: blob/main/crates/polars-io/src/path_utils/hugging_face.rs
7884 views
// Hugging Face path resolution support
//
// Resolves `hf://` URIs into direct `resolve/` download URLs by querying the
// Hugging Face Hub REST API, optionally expanding glob patterns over the
// repository file tree.

use std::borrow::Cow;

use polars_error::{PolarsResult, polars_bail, to_compute_err};
use polars_utils::plpath::PlPath;

use crate::cloud::{
    CloudConfig, CloudOptions, Matcher, USER_AGENT, extract_prefix_expansion,
    try_build_http_header_map_from_items_slice,
};
use crate::path_utils::HiveIdxTracker;
use crate::pl_async::with_concurrency_budget;
use crate::utils::{URL_ENCODE_CHARSET, decode_json_response};

/// Percent-encoding character set for HF Hub paths.
///
/// This is URL_ENCODE_CHARSET with slashes preserved - by not encoding slashes,
/// the API request will be counted under a higher "resolvers" ratelimit of (3000/5min)
/// compared to the default "pages" limit of (100/5min limit).
///
/// ref <https://github.com/pola-rs/polars/issues/25389>
const HF_PATH_ENCODE_CHARSET: &percent_encoding::AsciiSet = &URL_ENCODE_CHARSET.remove(b'/');

/// Components parsed out of an `hf://` URI by [`HFPathParts::try_from_uri`].
#[derive(Debug, PartialEq)]
struct HFPathParts {
    /// Either `"datasets"` or `"spaces"` (validated in `try_from_uri`).
    bucket: String,
    /// `{username}/{reponame}`, slash included.
    repository: String,
    /// Revision from the `@{revision}` suffix; defaults to `"main"` when absent.
    revision: String,
    /// Path relative to the repository root.
    path: String,
}

/// Precomputed base URLs for a single repository at a fixed revision:
/// one for the tree-listing API, one for file downloads.
struct HFRepoLocation {
    /// `https://huggingface.co/api/{bucket}/{repo}/tree/{revision}/`
    api_base_path: String,
    /// `https://huggingface.co/{bucket}/{repo}/resolve/{revision}/`
    download_base_path: String,
}

impl HFRepoLocation {
    /// Builds the API and download base URLs for `{bucket}/{repository}@{revision}`.
    fn new(bucket: &str, repository: &str, revision: &str) -> Self {
        // * Don't percent-encode bucket/repository - they are path segments where
        //   slashes are separators. E.g. "HuggingFaceFW/fineweb-2" must stay as-is.
        // * DO encode revision - slashes in revisions like "refs/convert/parquet"
        //   are part of the revision name, not path separators.
        //   See: https://github.com/pola-rs/polars/issues/25389
        let encoded_revision =
            percent_encoding::percent_encode(revision.as_bytes(), URL_ENCODE_CHARSET);
        let api_base_path = format!(
            "https://huggingface.co/api/{}/{}/tree/{}/",
            bucket, repository, encoded_revision
        );
        let download_base_path = format!(
            "https://huggingface.co/{}/{}/resolve/{}/",
            bucket, repository, encoded_revision
        );

        Self {
            api_base_path,
            download_base_path,
        }
    }

    /// Download URL for `rel_path`. Slashes in `rel_path` are kept as path
    /// separators; all other special characters are percent-encoded
    /// (see [`HF_PATH_ENCODE_CHARSET`]).
    fn get_file_uri(&self, rel_path: &str) -> String {
        format!(
            "{}{}",
            self.download_base_path,
            percent_encoding::percent_encode(rel_path.as_bytes(), HF_PATH_ENCODE_CHARSET)
        )
    }

    /// Tree-listing API URL for `rel_path`, encoded the same way as
    /// [`Self::get_file_uri`].
    fn get_api_uri(&self, rel_path: &str) -> String {
        format!(
            "{}{}",
            self.api_base_path,
            percent_encoding::percent_encode(rel_path.as_bytes(), HF_PATH_ENCODE_CHARSET)
        )
    }
}

impl HFPathParts {
    /// Extracts path components from a hugging face path:
    /// `hf:// [datasets | spaces] / {username} / {reponame} @ {revision} / {path from root}`
    ///
    /// The trailing `/{path from root}` is optional (`path` becomes `""`),
    /// as is `@{revision}` (`revision` defaults to `"main"`).
    ///
    /// # Errors
    /// Returns a `ComputeError` if the URI does not match the shape above or
    /// the bucket is not one of `datasets` / `spaces`.
    fn try_from_uri(uri: &str) -> PolarsResult<Self> {
        // The closure returns `None` on any malformed input; the `let ... else`
        // below converts that into a single error message.
        let Some(this) = (|| {
            // hf:// [datasets | spaces] / {username} / {reponame} @ {revision} / {path from root}
            // !>
            if !uri.starts_with("hf://") {
                return None;
            }
            let uri = &uri[5..];

            // [datasets | spaces] / {username} / {reponame} @ {revision} / {path from root}
            // ^-----------------^ !>
            let i = memchr::memchr(b'/', uri.as_bytes())?;
            let bucket = uri.get(..i)?.to_string();
            let uri = uri.get(1 + i..)?;

            // {username} / {reponame} @ {revision} / {path from root}
            // ^----------------------------------^ !>
            let i = memchr::memchr(b'/', uri.as_bytes())?;
            let i = {
                // Also handle if they just give the repository, i.e.:
                // hf:// [datasets | spaces] / {username} / {reponame} @ {revision}
                let uri = uri.get(1 + i..)?;
                if uri.is_empty() {
                    // "hf://datasets/a/" - username with empty reponame is invalid.
                    return None;
                }
                // End of `{username}/{reponame}[@{revision}]`: the next `/` after
                // the first one, or the end of the string if there is no path part.
                1 + i + memchr::memchr(b'/', uri.as_bytes()).unwrap_or(uri.len())
            };
            let repository = uri.get(..i)?;
            let uri = uri.get(1 + i..).unwrap_or("");

            let (repository, revision) =
                if let Some(i) = memchr::memchr(b'@', repository.as_bytes()) {
                    (repository[..i].to_string(), repository[1 + i..].to_string())
                } else {
                    // No @revision in uri, default to `main`
                    (repository.to_string(), "main".to_string())
                };

            // {path from root}
            // ^--------------^
            let path = uri.to_string();

            Some(HFPathParts {
                bucket,
                repository,
                revision,
                path,
            })
        })() else {
            polars_bail!(ComputeError: "invalid Hugging Face path: {}", uri);
        };

        const BUCKETS: [&str; 2] = ["datasets", "spaces"];
        if !BUCKETS.contains(&this.bucket.as_str()) {
            polars_bail!(ComputeError: "hugging face uri bucket must be one of {:?}, got {} instead.", BUCKETS, this.bucket);
        }

        Ok(this)
    }
}

/// One entry of the HF Hub tree-listing API response.
#[derive(Debug, serde::Deserialize)]
struct HFAPIResponse {
    // `type` is a Rust keyword, hence the rename.
    #[serde(rename = "type")]
    type_: String,
    /// Path of the entry relative to the repository root.
    path: String,
    /// Size in bytes as reported by the API.
    size: u64,
}

impl HFAPIResponse {
    /// True if this tree entry has type `"file"`.
    fn is_file(&self) -> bool {
        self.type_ == "file"
    }
}

/// API response is paginated with a `link` header.
/// * https://huggingface.co/docs/hub/en/api#get-apidatasets
/// * https://docs.github.com/en/rest/using-the-rest-api/using-pagination-in-the-rest-api?apiVersion=2022-11-28#using-link-headers
struct GetPages<'a> {
    client: &'a reqwest::Client,
    /// URI of the next page to fetch; `None` once pagination is exhausted
    /// (or after a request error, since the URI is `take()`n up front).
    uri: Option<String>,
}

impl GetPages<'_> {
    /// Fetches the next page, storing the following page's URI (from the
    /// response's `link` header, `rel="next"`) for the subsequent call.
    /// Returns `None` when there are no more pages.
    async fn next(&mut self) -> Option<PolarsResult<bytes::Bytes>> {
        let uri = self.uri.take()?;

        Some(
            async {
                // NOTE(review): budget of 1 - presumably serializes requests
                // against the shared budget in `pl_async`; confirm there.
                let resp = with_concurrency_budget(1, || async {
                    self.client.get(uri).send().await.map_err(to_compute_err)
                })
                .await?;

                self.uri = resp
                    .headers()
                    .get("link")
                    .and_then(|x| Self::find_link(x.as_bytes(), "next".as_bytes()))
                    .transpose()?;

                let resp_bytes = resp.bytes().await.map_err(to_compute_err)?;

                Ok(resp_bytes)
            }
            .await,
        )
    }

    /// Scans a `link` header value for the URI tagged `rel="{rel}"`.
    /// Returns `None` if no matching entry exists, `Some(Err(..))` if the
    /// matching URI is not valid UTF-8.
    fn find_link(mut link: &[u8], rel: &[u8]) -> Option<PolarsResult<String>> {
        // "<https://...>; rel=\"next\", <https://...>; rel=\"last\""
        while !link.is_empty() {
            // Capture the `<uri>` part.
            let i = memchr::memchr(b'<', link)?;
            link = link.get(1 + i..)?;
            let i = memchr::memchr(b'>', link)?;
            let uri = &link[..i];
            link = link.get(1 + i..)?;

            // Advance byte-by-byte to the `rel="` attribute; `get(1..)` on an
            // empty slice returns `None`, which terminates via `?`.
            while !link.starts_with("rel=\"".as_bytes()) {
                link = link.get(1..)?
            }

            // rel="next"
            link = link.get(5..)?;
            let i = memchr::memchr(b'"', link)?;

            if &link[..i] == rel {
                return Some(
                    std::str::from_utf8(uri)
                        .map_err(to_compute_err)
                        .map(ToString::to_string),
                );
            }
        }

        None
    }
}

/// Expands `hf://` paths into concrete download URIs.
///
/// For each input path:
/// * If it neither ends in `/` nor contains a glob expansion, a HEAD request
///   is tried first; a 200 response means it is a single file, which is
///   pushed directly.
/// * Otherwise the repository tree is listed recursively via the paginated
///   API and every non-empty file (filtered by the glob matcher, if any) is
///   collected.
///
/// Returns `(idx, paths)` where `idx` comes from `HiveIdxTracker`
/// (NOTE(review): presumably the prefix length used for hive-partition
/// inference - confirm against `path_utils`).
///
/// # Errors
/// Propagates URI-parse errors, HTTP/JSON errors, and header-construction
/// errors from the optional `CloudConfig::Http` headers.
pub(super) async fn expand_paths_hf(
    paths: &[PlPath],
    check_directory_level: bool,
    cloud_options: &Option<CloudOptions>,
    glob: bool,
) -> PolarsResult<(usize, Vec<PlPath>)> {
    assert!(!paths.is_empty());

    let client = reqwest::ClientBuilder::new()
        .user_agent(USER_AGENT)
        .http1_only()
        .https_only(true);

    // Apply user-supplied default headers (e.g. authorization) when provided.
    let client = if let Some(CloudOptions {
        config: Some(CloudConfig::Http { headers }),
        ..
    }) = cloud_options
    {
        client.default_headers(try_build_http_header_map_from_items_slice(
            headers.as_slice(),
        )?)
    } else {
        client
    };

    let client = &client.build().unwrap();

    let mut out_paths = vec![];
    let mut hive_idx_tracker = HiveIdxTracker {
        idx: usize::MAX,
        paths,
        check_directory_level,
    };

    for (path_idx, path) in paths.iter().enumerate() {
        let path_parts = &HFPathParts::try_from_uri(path.to_str())?;
        let repo_location = &HFRepoLocation::new(
            &path_parts.bucket,
            &path_parts.repository,
            &path_parts.revision,
        );
        let rel_path = path_parts.path.as_str();

        // With globbing enabled, split the path into a literal prefix and the
        // glob expansion; otherwise treat the whole path as literal.
        let (prefix, expansion) = if glob {
            extract_prefix_expansion(rel_path)?
        } else {
            (Cow::Owned(path_parts.path.clone()), None)
        };
        let expansion_matcher = &if expansion.is_some() {
            Some(Matcher::new(prefix.to_string(), expansion.as_deref())?)
        } else {
            None
        };

        let file_uri = repo_location.get_file_uri(rel_path);

        if !path_parts.path.ends_with("/") && expansion.is_none() {
            // Confirm that this is a file using a HEAD request.
            if with_concurrency_budget(1, || async {
                client.head(&file_uri).send().await.map_err(to_compute_err)
            })
            .await?
            .status()
                == 200
            {
                hive_idx_tracker.update(0, path_idx)?;
                out_paths.push(PlPath::from_string(file_uri));
                continue;
            }
            // Not a file (e.g. 404): fall through to the tree listing below.
        }

        hive_idx_tracker.update(file_uri.len(), path_idx)?;

        // List the tree recursively starting at the literal prefix; the
        // response is paginated via `link` headers (see `GetPages`).
        let uri = format!("{}?recursive=true", repo_location.get_api_uri(&prefix));
        let mut gp = GetPages {
            uri: Some(uri),
            client,
        };

        while let Some(bytes) = gp.next().await {
            let bytes = bytes?;
            let response: Vec<HFAPIResponse> = decode_json_response(bytes.as_ref())?;

            for entry in response {
                // Only include files with size > 0
                if entry.is_file() && entry.size > 0 {
                    // If we have a glob pattern, filter by it; otherwise include all files
                    let matches = if let Some(matcher) = expansion_matcher {
                        matcher.is_matching(entry.path.as_str())
                    } else {
                        true
                    };

                    if matches {
                        out_paths
                            .push(PlPath::from_string(repo_location.get_file_uri(&entry.path)));
                    }
                }
            }
        }
    }

    Ok((hive_idx_tracker.idx, out_paths))
}

// NOTE(review): not gated with #[cfg(test)]; consider adding it so the module
// is excluded from non-test builds.
mod tests {

    #[test]
    fn test_hf_path_from_uri() {
        use super::HFPathParts;

        // Full path with implicit `main` revision.
        let uri = "hf://datasets/pola-rs/polars/README.md";
        let expect = HFPathParts {
            bucket: "datasets".into(),
            repository: "pola-rs/polars".into(),
            revision: "main".into(),
            path: "README.md".into(),
        };

        assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect);

        // Explicit revision, trailing slash (empty path).
        let uri = "hf://spaces/pola-rs/polars@~parquet/";
        let expect = HFPathParts {
            bucket: "spaces".into(),
            repository: "pola-rs/polars".into(),
            revision: "~parquet".into(),
            path: "".into(),
        };

        assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect);

        // Explicit revision, no trailing slash - same result.
        let uri = "hf://spaces/pola-rs/polars@~parquet";
        let expect = HFPathParts {
            bucket: "spaces".into(),
            repository: "pola-rs/polars".into(),
            revision: "~parquet".into(),
            path: "".into(),
        };

        assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect);

        // Malformed URIs must all produce errors.
        for uri in [
            "://",
            "s3://",
            "https://",
            "hf://",
            "hf:///",
            "hf:////",
            "hf://datasets/a",
            "hf://datasets/a/",
            "hf://bucket/a/b/c", // Invalid bucket name
        ] {
            let out = HFPathParts::try_from_uri(uri);
            if out.is_err() {
                continue;
            }
            panic!("expected err result for uri {uri} instead of {out:?}");
        }
    }

    #[test]
    fn test_get_pages_find_next_link() {
        use super::GetPages;
        let link = r#"<https://api.github.com/repositories/263727855/issues?page=3>; rel="next", <https://api.github.com/repositories/263727855/issues?page=7>; rel="last""#.as_bytes();

        assert_eq!(
            GetPages::find_link(link, "next".as_bytes()).map(Result::unwrap),
            Some("https://api.github.com/repositories/263727855/issues?page=3".into()),
        );

        assert_eq!(
            GetPages::find_link(link, "last".as_bytes()).map(Result::unwrap),
            Some("https://api.github.com/repositories/263727855/issues?page=7".into()),
        );

        assert_eq!(
            GetPages::find_link(link, "non-existent".as_bytes()).map(Result::unwrap),
            None,
        );
    }

    #[test]
    fn test_hf_url_encoding() {
        // Verify URLs preserve slashes (don't encode as %2F) but encode special chars.
        // Slashes must remain for correct rate limit classification by HF Hub.
        // Special chars (spaces, colons) must be encoded for file downloads to work.
        // See: https://github.com/pola-rs/polars/issues/25389
        use super::HFRepoLocation;

        let loc = HFRepoLocation::new("datasets", "HuggingFaceFW/fineweb-2", "main");

        // Check base paths don't encode slashes
        assert_eq!(
            loc.api_base_path,
            "https://huggingface.co/api/datasets/HuggingFaceFW/fineweb-2/tree/main/"
        );
        assert_eq!(
            loc.download_base_path,
            "https://huggingface.co/datasets/HuggingFaceFW/fineweb-2/resolve/main/"
        );

        // Check file URIs preserve slashes in paths
        let file_uri = loc.get_file_uri("data/aai_Latn/train/000_00000.parquet");
        assert_eq!(
            file_uri,
            "https://huggingface.co/datasets/HuggingFaceFW/fineweb-2/resolve/main/data/aai_Latn/train/000_00000.parquet"
        );

        // Check that special characters ARE encoded (spaces -> %20, colons -> %3A)
        // This is needed for hive-partitioned paths like "date2=2023-01-01 00:00:00.000000"
        let file_uri = loc.get_file_uri(
            "hive_dates/date1=2024-01-01/date2=2023-01-01 00:00:00.000000/00000000.parquet",
        );
        assert_eq!(
            file_uri,
            "https://huggingface.co/datasets/HuggingFaceFW/fineweb-2/resolve/main/hive_dates/date1%3D2024-01-01/date2%3D2023-01-01%2000%3A00%3A00.000000/00000000.parquet"
        );

        // Check that brackets are encoded ([ -> %5B, ] -> %5D)
        let file_uri = loc.get_file_uri("special-chars/[*.parquet");
        assert_eq!(
            file_uri,
            "https://huggingface.co/datasets/HuggingFaceFW/fineweb-2/resolve/main/special-chars/%5B%2A.parquet"
        );

        // Check that revision slashes ARE encoded (they're part of the revision name)
        // e.g. "refs/convert/parquet" -> "refs%2Fconvert%2Fparquet"
        let loc = HFRepoLocation::new("datasets", "user/repo", "refs/convert/parquet");
        assert_eq!(
            loc.api_base_path,
            "https://huggingface.co/api/datasets/user/repo/tree/refs%2Fconvert%2Fparquet/"
        );
        assert_eq!(
            loc.download_base_path,
            "https://huggingface.co/datasets/user/repo/resolve/refs%2Fconvert%2Fparquet/"
        );
    }
}