Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-io/src/path_utils/hugging_face.rs
7884 views
1
// Hugging Face path resolution support
2
3
use std::borrow::Cow;
4
5
use polars_error::{PolarsResult, polars_bail, to_compute_err};
6
use polars_utils::plpath::PlPath;
7
8
use crate::cloud::{
9
CloudConfig, CloudOptions, Matcher, USER_AGENT, extract_prefix_expansion,
10
try_build_http_header_map_from_items_slice,
11
};
12
use crate::path_utils::HiveIdxTracker;
13
use crate::pl_async::with_concurrency_budget;
14
use crate::utils::{URL_ENCODE_CHARSET, decode_json_response};
15
16
/// Percent-encoding character set for HF Hub paths.
///
/// This is URL_ENCODE_CHARSET with slashes preserved - by not encoding slashes,
/// the API request will be counted under a higher "resolvers" ratelimit of (3000/5min)
/// compared to the default "pages" limit of (100/5min limit).
///
/// ref <https://github.com/pola-rs/polars/issues/25389>
const HF_PATH_ENCODE_CHARSET: &percent_encoding::AsciiSet = &URL_ENCODE_CHARSET.remove(b'/');
24
25
/// Components extracted from an `hf://` URI by [`HFPathParts::try_from_uri`].
#[derive(Debug, PartialEq)]
struct HFPathParts {
    // Validated to be one of "datasets" / "spaces".
    bucket: String,
    // "{username}/{reponame}", without any "@{revision}" suffix.
    repository: String,
    // Revision/ref name; defaults to "main" when the URI has no "@{revision}".
    revision: String,
    /// Path relative to the repository root.
    path: String,
}
33
34
/// Pre-built base URLs for one repository at one revision on the HF Hub.
struct HFRepoLocation {
    // "https://huggingface.co/api/{bucket}/{repo}/tree/{revision}/" - file listing.
    api_base_path: String,
    // "https://huggingface.co/{bucket}/{repo}/resolve/{revision}/" - file downloads.
    download_base_path: String,
}
38
39
impl HFRepoLocation {
40
fn new(bucket: &str, repository: &str, revision: &str) -> Self {
41
// * Don't percent-encode bucket/repository - they are path segments where
42
// slashes are separators. E.g. "HuggingFaceFW/fineweb-2" must stay as-is.
43
// * DO encode revision - slashes in revisions like "refs/convert/parquet"
44
// are part of the revision name, not path separators.
45
// See: https://github.com/pola-rs/polars/issues/25389
46
let encoded_revision =
47
percent_encoding::percent_encode(revision.as_bytes(), URL_ENCODE_CHARSET);
48
let api_base_path = format!(
49
"https://huggingface.co/api/{}/{}/tree/{}/",
50
bucket, repository, encoded_revision
51
);
52
let download_base_path = format!(
53
"https://huggingface.co/{}/{}/resolve/{}/",
54
bucket, repository, encoded_revision
55
);
56
57
Self {
58
api_base_path,
59
download_base_path,
60
}
61
}
62
63
fn get_file_uri(&self, rel_path: &str) -> String {
64
format!(
65
"{}{}",
66
self.download_base_path,
67
percent_encoding::percent_encode(rel_path.as_bytes(), HF_PATH_ENCODE_CHARSET)
68
)
69
}
70
71
fn get_api_uri(&self, rel_path: &str) -> String {
72
format!(
73
"{}{}",
74
self.api_base_path,
75
percent_encoding::percent_encode(rel_path.as_bytes(), HF_PATH_ENCODE_CHARSET)
76
)
77
}
78
}
79
80
impl HFPathParts {
    /// Extracts path components from a hugging face path:
    /// `hf:// [datasets | spaces] / {username} / {reponame} @ {revision} / {path from root}`
    ///
    /// # Errors
    /// Returns a `ComputeError` when the URI is malformed (wrong scheme, missing
    /// repository component) or when the bucket is not one of `datasets`/`spaces`.
    fn try_from_uri(uri: &str) -> PolarsResult<Self> {
        // The closure yields `None` for any malformed input; `let .. else` turns
        // that into a single error message quoting the offending URI.
        let Some(this) = (|| {
            // hf:// [datasets | spaces] / {username} / {reponame} @ {revision} / {path from root}
            // !>
            if !uri.starts_with("hf://") {
                return None;
            }
            // Strip the "hf://" scheme (5 bytes).
            let uri = &uri[5..];

            // [datasets | spaces] / {username} / {reponame} @ {revision} / {path from root}
            // ^-----------------^ !>
            let i = memchr::memchr(b'/', uri.as_bytes())?;
            let bucket = uri.get(..i)?.to_string();
            let uri = uri.get(1 + i..)?;

            // {username} / {reponame} @ {revision} / {path from root}
            // ^----------------------------------^ !>
            let i = memchr::memchr(b'/', uri.as_bytes())?;
            let i = {
                // Also handle if they just give the repository, i.e.:
                // hf:// [datasets | spaces] / {username} / {reponame} @ {revision}
                let uri = uri.get(1 + i..)?;
                if uri.is_empty() {
                    return None;
                }

                // End of "{username}/{reponame}[@{revision}]": the next '/' after
                // the username's separator, or the end of the string if absent.
                1 + i + memchr::memchr(b'/', uri.as_bytes()).unwrap_or(uri.len())
            };
            let repository = uri.get(..i)?;
            // Skip the separator; empty remainder means no path was given.
            let uri = uri.get(1 + i..).unwrap_or("");

            let (repository, revision) =
                if let Some(i) = memchr::memchr(b'@', repository.as_bytes()) {
                    (repository[..i].to_string(), repository[1 + i..].to_string())
                } else {
                    // No @revision in uri, default to `main`
                    (repository.to_string(), "main".to_string())
                };

            // {path from root}
            // ^--------------^
            let path = uri.to_string();

            Some(HFPathParts {
                bucket,
                repository,
                revision,
                path,
            })
        })() else {
            polars_bail!(ComputeError: "invalid Hugging Face path: {}", uri);
        };

        const BUCKETS: [&str; 2] = ["datasets", "spaces"];
        if !BUCKETS.contains(&this.bucket.as_str()) {
            polars_bail!(ComputeError: "hugging face uri bucket must be one of {:?}, got {} instead.", BUCKETS, this.bucket);
        }

        Ok(this)
    }
}
143
144
/// One entry deserialized from the HF Hub tree-listing API response.
#[derive(Debug, serde::Deserialize)]
struct HFAPIResponse {
    // Entry kind as reported by the API; compared against "file" in `is_file`.
    // Renamed because `type` is a Rust keyword.
    #[serde(rename = "type")]
    type_: String,
    // Path of the entry relative to the repository root.
    path: String,
    // Size in bytes.
    size: u64,
}
151
152
impl HFAPIResponse {
153
fn is_file(&self) -> bool {
154
self.type_ == "file"
155
}
156
}
157
158
/// API response is paginated with a `link` header.
/// * https://huggingface.co/docs/hub/en/api#get-apidatasets
/// * https://docs.github.com/en/rest/using-the-rest-api/using-pagination-in-the-rest-api?apiVersion=2022-11-28#using-link-headers
struct GetPages<'a> {
    client: &'a reqwest::Client,
    // URI of the next page to fetch; `None` once the final page was consumed.
    uri: Option<String>,
}
165
166
impl GetPages<'_> {
    /// Fetches the next page of the paginated response.
    ///
    /// Returns `None` once exhausted, otherwise the raw response bytes. Before
    /// yielding, stores the `rel="next"` URI from the `link` header (if any)
    /// for the following call.
    async fn next(&mut self) -> Option<PolarsResult<bytes::Bytes>> {
        // Taking the URI means any early-error exit also terminates iteration.
        let uri = self.uri.take()?;

        Some(
            async {
                let resp = with_concurrency_budget(1, || async {
                    self.client.get(uri).send().await.map_err(to_compute_err)
                })
                .await?;

                self.uri = resp
                    .headers()
                    .get("link")
                    .and_then(|x| Self::find_link(x.as_bytes(), "next".as_bytes()))
                    .transpose()?;

                let resp_bytes = resp.bytes().await.map_err(to_compute_err)?;

                Ok(resp_bytes)
            }
            .await,
        )
    }

    /// Scans a raw `link` header for the URI tagged with the given `rel` value.
    ///
    /// Returns `None` when no matching `rel` is present or the header is
    /// malformed; `Some(Err(..))` only when the matched URI is not valid UTF-8.
    fn find_link(mut link: &[u8], rel: &[u8]) -> Option<PolarsResult<String>> {
        // "<https://...>; rel=\"next\", <https://...>; rel=\"last\""
        while !link.is_empty() {
            // Extract the URI between '<' and '>'.
            let i = memchr::memchr(b'<', link)?;
            link = link.get(1 + i..)?;
            let i = memchr::memchr(b'>', link)?;
            let uri = &link[..i];
            link = link.get(1 + i..)?;

            // Advance to the `rel="` attribute of this entry.
            while !link.starts_with("rel=\"".as_bytes()) {
                link = link.get(1..)?
            }

            // rel="next"
            link = link.get(5..)?;
            let i = memchr::memchr(b'"', link)?;

            if &link[..i] == rel {
                return Some(
                    std::str::from_utf8(uri)
                        .map_err(to_compute_err)
                        .map(ToString::to_string),
                );
            }
        }

        None
    }
}
220
221
/// Expands `hf://` paths into concrete `https://huggingface.co/.../resolve/...`
/// download URIs by querying the HF Hub tree API.
///
/// * `check_directory_level` is forwarded to `HiveIdxTracker` (hive-partition
///   directory-level consistency checking).
/// * `glob` enables glob-pattern expansion of the path component.
///
/// Returns `(hive_idx, expanded_paths)`, where `hive_idx` is the tracker's
/// resulting index.
///
/// # Panics
/// Panics if `paths` is empty.
pub(super) async fn expand_paths_hf(
    paths: &[PlPath],
    check_directory_level: bool,
    cloud_options: &Option<CloudOptions>,
    glob: bool,
) -> PolarsResult<(usize, Vec<PlPath>)> {
    assert!(!paths.is_empty());

    let client = reqwest::ClientBuilder::new()
        .user_agent(USER_AGENT)
        .http1_only()
        .https_only(true);

    // Apply user-supplied HTTP headers (e.g. authorization) when configured.
    let client = if let Some(CloudOptions {
        config: Some(CloudConfig::Http { headers }),
        ..
    }) = cloud_options
    {
        client.default_headers(try_build_http_header_map_from_items_slice(
            headers.as_slice(),
        )?)
    } else {
        client
    };

    let client = &client.build().unwrap();

    let mut out_paths = vec![];
    let mut hive_idx_tracker = HiveIdxTracker {
        idx: usize::MAX,
        paths,
        check_directory_level,
    };

    for (path_idx, path) in paths.iter().enumerate() {
        let path_parts = &HFPathParts::try_from_uri(path.to_str())?;
        let repo_location = &HFRepoLocation::new(
            &path_parts.bucket,
            &path_parts.repository,
            &path_parts.revision,
        );
        let rel_path = path_parts.path.as_str();

        // With globbing enabled, split the path into a literal prefix and a
        // glob expansion; otherwise treat the whole path as literal.
        let (prefix, expansion) = if glob {
            extract_prefix_expansion(rel_path)?
        } else {
            (Cow::Owned(path_parts.path.clone()), None)
        };
        let expansion_matcher = &if expansion.is_some() {
            Some(Matcher::new(prefix.to_string(), expansion.as_deref())?)
        } else {
            None
        };

        let file_uri = repo_location.get_file_uri(rel_path);

        // Fast path: a non-glob path without a trailing '/' may be a single file.
        if !path_parts.path.ends_with("/") && expansion.is_none() {
            // Confirm that this is a file using a HEAD request.
            if with_concurrency_budget(1, || async {
                client.head(&file_uri).send().await.map_err(to_compute_err)
            })
            .await?
            .status()
                == 200
            {
                hive_idx_tracker.update(0, path_idx)?;
                out_paths.push(PlPath::from_string(file_uri));
                continue;
            }
        }

        hive_idx_tracker.update(file_uri.len(), path_idx)?;

        // Otherwise list the tree (recursively) under the literal prefix,
        // following pagination via the `link` header.
        let uri = format!("{}?recursive=true", repo_location.get_api_uri(&prefix));
        let mut gp = GetPages {
            uri: Some(uri),
            client,
        };

        while let Some(bytes) = gp.next().await {
            let bytes = bytes?;
            let response: Vec<HFAPIResponse> = decode_json_response(bytes.as_ref())?;

            for entry in response {
                // Only include files with size > 0
                if entry.is_file() && entry.size > 0 {
                    // If we have a glob pattern, filter by it; otherwise include all files
                    let matches = if let Some(matcher) = expansion_matcher {
                        matcher.is_matching(entry.path.as_str())
                    } else {
                        true
                    };

                    if matches {
                        out_paths
                            .push(PlPath::from_string(repo_location.get_file_uri(&entry.path)));
                    }
                }
            }
        }
    }

    Ok((hive_idx_tracker.idx, out_paths))
}
325
326
// Fix: gate the tests module behind `#[cfg(test)]` so it is not compiled into
// non-test builds (avoids dead-code warnings and needless compilation).
#[cfg(test)]
mod tests {

    #[test]
    fn test_hf_path_from_uri() {
        use super::HFPathParts;

        let uri = "hf://datasets/pola-rs/polars/README.md";
        let expect = HFPathParts {
            bucket: "datasets".into(),
            repository: "pola-rs/polars".into(),
            revision: "main".into(),
            path: "README.md".into(),
        };

        assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect);

        let uri = "hf://spaces/pola-rs/polars@~parquet/";
        let expect = HFPathParts {
            bucket: "spaces".into(),
            repository: "pola-rs/polars".into(),
            revision: "~parquet".into(),
            path: "".into(),
        };

        assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect);

        let uri = "hf://spaces/pola-rs/polars@~parquet";
        let expect = HFPathParts {
            bucket: "spaces".into(),
            repository: "pola-rs/polars".into(),
            revision: "~parquet".into(),
            path: "".into(),
        };

        assert_eq!(HFPathParts::try_from_uri(uri).unwrap(), expect);

        for uri in [
            "://",
            "s3://",
            "https://",
            "hf://",
            "hf:///",
            "hf:////",
            "hf://datasets/a",
            "hf://datasets/a/",
            "hf://bucket/a/b/c", // Invalid bucket name
        ] {
            let out = HFPathParts::try_from_uri(uri);
            if out.is_err() {
                continue;
            }
            panic!("expected err result for uri {uri} instead of {out:?}");
        }
    }

    #[test]
    fn test_get_pages_find_next_link() {
        use super::GetPages;
        let link = r#"<https://api.github.com/repositories/263727855/issues?page=3>; rel="next", <https://api.github.com/repositories/263727855/issues?page=7>; rel="last""#.as_bytes();

        assert_eq!(
            GetPages::find_link(link, "next".as_bytes()).map(Result::unwrap),
            Some("https://api.github.com/repositories/263727855/issues?page=3".into()),
        );

        assert_eq!(
            GetPages::find_link(link, "last".as_bytes()).map(Result::unwrap),
            Some("https://api.github.com/repositories/263727855/issues?page=7".into()),
        );

        assert_eq!(
            GetPages::find_link(link, "non-existent".as_bytes()).map(Result::unwrap),
            None,
        );
    }

    #[test]
    fn test_hf_url_encoding() {
        // Verify URLs preserve slashes (don't encode as %2F) but encode special chars.
        // Slashes must remain for correct rate limit classification by HF Hub.
        // Special chars (spaces, colons) must be encoded for file downloads to work.
        // See: https://github.com/pola-rs/polars/issues/25389
        use super::HFRepoLocation;

        let loc = HFRepoLocation::new("datasets", "HuggingFaceFW/fineweb-2", "main");

        // Check base paths don't encode slashes
        assert_eq!(
            loc.api_base_path,
            "https://huggingface.co/api/datasets/HuggingFaceFW/fineweb-2/tree/main/"
        );
        assert_eq!(
            loc.download_base_path,
            "https://huggingface.co/datasets/HuggingFaceFW/fineweb-2/resolve/main/"
        );

        // Check file URIs preserve slashes in paths
        let file_uri = loc.get_file_uri("data/aai_Latn/train/000_00000.parquet");
        assert_eq!(
            file_uri,
            "https://huggingface.co/datasets/HuggingFaceFW/fineweb-2/resolve/main/data/aai_Latn/train/000_00000.parquet"
        );

        // Check that special characters ARE encoded (spaces -> %20, colons -> %3A)
        // This is needed for hive-partitioned paths like "date2=2023-01-01 00:00:00.000000"
        let file_uri = loc.get_file_uri(
            "hive_dates/date1=2024-01-01/date2=2023-01-01 00:00:00.000000/00000000.parquet",
        );
        assert_eq!(
            file_uri,
            "https://huggingface.co/datasets/HuggingFaceFW/fineweb-2/resolve/main/hive_dates/date1%3D2024-01-01/date2%3D2023-01-01%2000%3A00%3A00.000000/00000000.parquet"
        );

        // Check that brackets are encoded ([ -> %5B, ] -> %5D)
        let file_uri = loc.get_file_uri("special-chars/[*.parquet");
        assert_eq!(
            file_uri,
            "https://huggingface.co/datasets/HuggingFaceFW/fineweb-2/resolve/main/special-chars/%5B%2A.parquet"
        );

        // Check that revision slashes ARE encoded (they're part of the revision name)
        // e.g. "refs/convert/parquet" -> "refs%2Fconvert%2Fparquet"
        let loc = HFRepoLocation::new("datasets", "user/repo", "refs/convert/parquet");
        assert_eq!(
            loc.api_base_path,
            "https://huggingface.co/api/datasets/user/repo/tree/refs%2Fconvert%2Fparquet/"
        );
        assert_eq!(
            loc.download_base_path,
            "https://huggingface.co/datasets/user/repo/resolve/refs%2Fconvert%2Fparquet/"
        );
    }
}
459
460