GitHub Repository: TheLastBen/fast-stable-diffusion
Path: blob/main/Dreambooth/hub.py

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Hub utilities: utilities for downloading and caching models.
"""
import json
import os
import re
import shutil
import sys
import tempfile
import traceback
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse
from uuid import uuid4

import huggingface_hub
import requests
from huggingface_hub import (
    CommitOperationAdd,
    HfFolder,
    create_commit,
    create_repo,
    get_hf_file_metadata,
    hf_hub_download,
    hf_hub_url,
    whoami,
)
from huggingface_hub.file_download import REGEX_COMMIT_HASH, http_get
from huggingface_hub.utils import (
    EntryNotFoundError,
    LocalEntryNotFoundError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
    hf_raise_for_status,
)
from requests.exceptions import HTTPError
from transformers.utils.logging import tqdm

from . import __version__, logging
from .generic import working_or_temp_dir
from .import_utils import (
    ENV_VARS_TRUE_VALUES,
    _tf_version,
    _torch_version,
    is_tf_available,
    is_torch_available,
    is_training_run_on_sagemaker,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

_is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False


def is_offline_mode():
    return _is_offline_mode

torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
old_default_cache_path = os.path.join(torch_cache_home, "transformers")
# New default cache, shared with the Datasets library
hf_cache_home = os.path.expanduser(
    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "/kaggle/working/.cache"), "huggingface"))
)
default_cache_path = os.path.join(hf_cache_home, "hub")

# One-time move from the old location to the new one if no ENV variable has been set.
if (
    os.path.isdir(old_default_cache_path)
    and not os.path.isdir(default_cache_path)
    and "PYTORCH_PRETRAINED_BERT_CACHE" not in os.environ
    and "PYTORCH_TRANSFORMERS_CACHE" not in os.environ
    and "TRANSFORMERS_CACHE" not in os.environ
):
    logger.warning(
        "In Transformers v4.0.0, the default path to cache downloaded models changed from"
        " '~/.cache/torch/transformers' to '~/.cache/huggingface/transformers'. Since you don't seem to have"
        " overridden it and '~/.cache/torch/transformers' is a directory that exists, we're moving it to"
        " '~/.cache/huggingface/transformers' to avoid redownloading models you have already in the cache. You should"
        " only see this message once."
    )
    shutil.move(old_default_cache_path, default_cache_path)

PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
HUGGINGFACE_HUB_CACHE = os.getenv("HUGGINGFACE_HUB_CACHE", PYTORCH_TRANSFORMERS_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", HUGGINGFACE_HUB_CACHE)
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
TRANSFORMERS_DYNAMIC_MODULE_NAME = "transformers_modules"
SESSION_ID = uuid4().hex
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", False) in ENV_VARS_TRUE_VALUES

S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"

_staging_mode = os.environ.get("HUGGINGFACE_CO_STAGING", "NO").upper() in ENV_VARS_TRUE_VALUES
_default_endpoint = "https://hub-ci.huggingface.co" if _staging_mode else "https://huggingface.co"

HUGGINGFACE_CO_RESOLVE_ENDPOINT = _default_endpoint
if os.environ.get("HUGGINGFACE_CO_RESOLVE_ENDPOINT", None) is not None:
    warnings.warn(
        "Using the environment variable `HUGGINGFACE_CO_RESOLVE_ENDPOINT` is deprecated and will be removed in "
        "Transformers v5. Use `HF_ENDPOINT` instead.",
        FutureWarning,
    )
    HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HUGGINGFACE_CO_RESOLVE_ENDPOINT", None)
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", HUGGINGFACE_CO_RESOLVE_ENDPOINT)
HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"

# Return value when trying to load a file from cache but the file does not exist in the distant repo.
_CACHED_NO_EXIST = object()

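# Illustrative sketch (added for exposition, not part of the original module): the
# effective cache directory is the first of these environment variables that is set,
# falling back to the default: TRANSFORMERS_CACHE -> HUGGINGFACE_HUB_CACHE ->
# PYTORCH_TRANSFORMERS_CACHE -> PYTORCH_PRETRAINED_BERT_CACHE -> <hf_cache_home>/hub.
def _example_inspect_cache_location():
    # Print the resolved cache directory and whether an env variable customized it.
    print("cache dir:", TRANSFORMERS_CACHE)
    print("customized:", TRANSFORMERS_CACHE != default_cache_path)
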
def is_remote_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")

def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
    """
    Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape `(model_url,
    etag, size_MB)`. Filenames in `cache_dir` are used to get the metadata for each model; only urls ending with
    *.bin* are added.

    Args:
        cache_dir (`Union[str, Path]`, *optional*):
            The cache directory to search for models within. Will default to the transformers cache if unset.

    Returns:
        List[Tuple]: List of tuples each with shape `(model_url, etag, size_MB)`
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    elif isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    if not os.path.isdir(cache_dir):
        return []

    cached_models = []
    for file in os.listdir(cache_dir):
        if file.endswith(".json"):
            meta_path = os.path.join(cache_dir, file)
            with open(meta_path, encoding="utf-8") as meta_file:
                metadata = json.load(meta_file)
            url = metadata["url"]
            etag = metadata["etag"]
            if url.endswith(".bin"):
                # Drop the ".json" suffix to get the model file path (`str.strip(".json")`
                # would also eat matching characters from the filename itself).
                size_MB = os.path.getsize(meta_path[: -len(".json")]) / 1e6
                cached_models.append((url, etag, size_MB))

    return cached_models

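# Hedged usage sketch (an addition, not upstream code): list the *.bin checkpoints
# recorded in the legacy flat cache together with their sizes.
def _example_list_cached_models():
    for model_url, etag, size_mb in get_cached_models():
        print(f"{model_url}: {size_mb:.1f} MB (etag={etag})")
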
def define_sagemaker_information():
    try:
        instance_data = requests.get(os.environ["ECS_CONTAINER_METADATA_URI"]).json()
        dlc_container_used = instance_data["Image"]
        dlc_tag = instance_data["Image"].split(":")[1]
    except Exception:
        dlc_container_used = None
        dlc_tag = None

    sagemaker_params = json.loads(os.getenv("SM_FRAMEWORK_PARAMS", "{}"))
    runs_distributed_training = True if "sagemaker_distributed_dataparallel_enabled" in sagemaker_params else False
    account_id = os.getenv("TRAINING_JOB_ARN").split(":")[4] if "TRAINING_JOB_ARN" in os.environ else None

    sagemaker_object = {
        "sm_framework": os.getenv("SM_FRAMEWORK_MODULE", None),
        "sm_region": os.getenv("AWS_REGION", None),
        "sm_number_gpu": os.getenv("SM_NUM_GPUS", 0),
        "sm_number_cpu": os.getenv("SM_NUM_CPUS", 0),
        "sm_distributed_training": runs_distributed_training,
        "sm_deep_learning_container": dlc_container_used,
        "sm_deep_learning_container_tag": dlc_tag,
        "sm_account_id": account_id,
    }
    return sagemaker_object

def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    """
    Formats a user-agent string with basic info about a request.
    """
    ua = f"transformers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
    if is_torch_available():
        ua += f"; torch/{_torch_version}"
    if is_tf_available():
        ua += f"; tensorflow/{_tf_version}"
    if DISABLE_TELEMETRY:
        return ua + "; telemetry/off"
    if is_training_run_on_sagemaker():
        ua += "; " + "; ".join(f"{k}/{v}" for k, v in define_sagemaker_information().items())
    # CI will set this value to True
    if os.environ.get("TRANSFORMERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
        ua += "; is_ci/true"
    if isinstance(user_agent, dict):
        ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        ua += "; " + user_agent
    return ua

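# Illustrative sketch (added, not upstream code): extra key/value pairs passed as a
# dict are appended to the base user-agent string. The keys below are arbitrary
# example values.
def _example_user_agent():
    ua = http_user_agent({"file_type": "model", "framework": "pytorch"})
    # e.g. "transformers/<version>; python/3.10.6; session_id/<hex>; file_type/model; framework/pytorch"
    print(ua)
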
def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]):
    """
    Extracts the commit hash from a resolved filename pointing to a cache file.
    """
    if resolved_file is None or commit_hash is not None:
        return commit_hash
    resolved_file = str(Path(resolved_file).as_posix())
    search = re.search(r"snapshots/([^/]+)/", resolved_file)
    if search is None:
        return None
    commit_hash = search.groups()[0]
    return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None

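# Illustrative sketch (added, not upstream code): the commit hash is recovered from
# the "snapshots/<sha>/" segment of a resolved cache path. The path below is a
# made-up example that follows the hub cache layout.
def _example_extract_commit_hash():
    resolved = (
        "/root/.cache/huggingface/hub/models--bert-base-uncased/snapshots/"
        "0a6aa9128b6194f4f3c4db429b6cb4891cdb421b/config.json"
    )
    # Prints the 40-character sha, since it matches REGEX_COMMIT_HASH.
    print(extract_commit_hash(resolved, None))
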
def try_to_load_from_cache(
    repo_id: str,
    filename: str,
    cache_dir: Union[str, Path, None] = None,
    revision: Optional[str] = None,
) -> Optional[str]:
    """
    Explores the cache to return the latest cached file for a given revision if found.

    This function will not raise any exception if the file is not cached.

    Args:
        cache_dir (`str` or `os.PathLike`):
            The folder where the cached files lie.
        repo_id (`str`):
            The ID of the repo on huggingface.co.
        filename (`str`):
            The filename to look for inside `repo_id`.
        revision (`str`, *optional*):
            The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
            provided either.

    Returns:
        `Optional[str]` or `_CACHED_NO_EXIST`:
            Will return `None` if the file was not cached. Otherwise:
            - The exact path to the cached file if it's found in the cache
            - A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
              cached.
    """
    if revision is None:
        revision = "main"

    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE

    object_id = repo_id.replace("/", "--")
    repo_cache = os.path.join(cache_dir, f"models--{object_id}")
    if not os.path.isdir(repo_cache):
        # No cache for this model
        return None
    for subfolder in ["refs", "snapshots"]:
        if not os.path.isdir(os.path.join(repo_cache, subfolder)):
            return None

    # Resolve refs (for instance to convert main to the associated commit sha)
    cached_refs = os.listdir(os.path.join(repo_cache, "refs"))
    if revision in cached_refs:
        with open(os.path.join(repo_cache, "refs", revision)) as f:
            revision = f.read()

    if os.path.isfile(os.path.join(repo_cache, ".no_exist", revision, filename)):
        return _CACHED_NO_EXIST

    cached_shas = os.listdir(os.path.join(repo_cache, "snapshots"))
    if revision not in cached_shas:
        # No cache for this revision and we won't try to return a random revision
        return None

    cached_file = os.path.join(repo_cache, "snapshots", revision, filename)
    return cached_file if os.path.isfile(cached_file) else None

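# Hedged usage sketch (an addition, not upstream code): the three possible outcomes
# of a cache lookup. The repo id and filename are assumptions for the example.
def _example_cache_lookup():
    resolved = try_to_load_from_cache("bert-base-uncased", "config.json")
    if resolved is None:
        print("file was never cached for this revision")
    elif resolved is _CACHED_NO_EXIST:
        print("non-existence of the file on the Hub is itself cached")
    else:
        print("cached at:", resolved)
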
def cached_file(
    path_or_repo_id: Union[str, os.PathLike],
    filename: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
    user_agent: Optional[Union[str, Dict[str, str]]] = None,
    _raise_exceptions_for_missing_entries: bool = True,
    _raise_exceptions_for_connection_errors: bool = True,
    _commit_hash: Optional[str] = None,
):
    """
    Tries to locate a file in a local folder and repo, downloads and caches it if necessary.

    Args:
        path_or_repo_id (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a model repo on huggingface.co.
            - a path to a *directory* potentially containing the file.
        filename (`str`):
            The name of the file to locate in `path_or_repo`.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force (re-)downloading the configuration files and override the cached versions if they
            exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete an incompletely received file. Attempts to resume the download if such a file
            exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        subfolder (`str`, *optional*, defaults to `""`):
            In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
            specify the folder name here.

    <Tip>

    Passing `use_auth_token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo).

    Examples:

    ```python
    # Download a model weight from the Hub and cache it.
    model_weights_file = cached_file("bert-base-uncased", "pytorch_model.bin")
    ```"""
    # Private arguments
    #     _raise_exceptions_for_missing_entries: if False, do not raise an exception for missing entries but return
    #         None.
    #     _raise_exceptions_for_connection_errors: if False, do not raise an exception for connection errors but return
    #         None.
    #     _commit_hash: passed when we are chaining several calls to various files (e.g. when loading a tokenizer or
    #         a pipeline). If files are cached for this commit hash, avoid calls to head and get from the cache.
    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True
    if subfolder is None:
        subfolder = ""

    path_or_repo_id = str(path_or_repo_id)
    full_filename = os.path.join(subfolder, filename)
    if os.path.isdir(path_or_repo_id):
        resolved_file = os.path.join(os.path.join(path_or_repo_id, subfolder), filename)
        if not os.path.isfile(resolved_file):
            if _raise_exceptions_for_missing_entries:
                raise EnvironmentError(
                    f"{path_or_repo_id} does not appear to have a file named {full_filename}. Check out "
                    f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files."
                )
            else:
                return None
        return resolved_file

    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if _commit_hash is not None:
        # If the file is cached under that commit hash, we return it directly.
        resolved_file = try_to_load_from_cache(
            path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash
        )
        if resolved_file is not None:
            if resolved_file is not _CACHED_NO_EXIST:
                return resolved_file
            elif not _raise_exceptions_for_missing_entries:
                return None
            else:
                raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.")

    user_agent = http_user_agent(user_agent)
    try:
        # Load from URL or cache if already cached
        resolved_file = hf_hub_download(
            path_or_repo_id,
            filename,
            subfolder=None if len(subfolder) == 0 else subfolder,
            revision=revision,
            cache_dir=cache_dir,
            user_agent=user_agent,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
        )

    except RepositoryNotFoundError:
        raise EnvironmentError(
            f"{path_or_repo_id} is not a local folder and is not a valid model identifier "
            "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
            "pass a token having permission to this repo with `use_auth_token` or log in with "
            "`huggingface-cli login` and pass `use_auth_token=True`."
        )
    except RevisionNotFoundError:
        raise EnvironmentError(
            f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
            "for this model name. Check the model page at "
            f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
        )
    except LocalEntryNotFoundError:
        # We try to see if we have a cached version (not up to date):
        resolved_file = try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision)
        if resolved_file is not None and resolved_file != _CACHED_NO_EXIST:
            return resolved_file
        if not _raise_exceptions_for_missing_entries or not _raise_exceptions_for_connection_errors:
            return None
        raise EnvironmentError(
            f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
            f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
            f" {full_filename}.\nCheck your internet connection or see how to run the library in offline mode at"
            " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
        )
    except EntryNotFoundError:
        if not _raise_exceptions_for_missing_entries:
            return None
        if revision is None:
            revision = "main"
        raise EnvironmentError(
            f"{path_or_repo_id} does not appear to have a file named {full_filename}. Check out "
            f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files."
        )
    except HTTPError as err:
        # First we try to see if we have a cached version (not up to date):
        resolved_file = try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision)
        if resolved_file is not None and resolved_file != _CACHED_NO_EXIST:
            return resolved_file
        if not _raise_exceptions_for_connection_errors:
            return None

        raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}")

    return resolved_file

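# Hedged sketch (added for exposition, not upstream code): resolving a file at a
# pinned revision; the repo id and revision are assumptions for the example.
def _example_cached_file_pinned():
    resolved = cached_file("bert-base-uncased", "config.json", revision="main")
    print(resolved)  # absolute path inside the local hub cache
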
def get_file_from_repo(
    path_or_repo: Union[str, os.PathLike],
    filename: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
):
    """
    Tries to locate a file in a local folder and repo, downloads and caches it if necessary.

    Args:
        path_or_repo (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a model repo on huggingface.co.
            - a path to a *directory* potentially containing the file.
        filename (`str`):
            The name of the file to locate in `path_or_repo`.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force (re-)downloading the configuration files and override the cached versions if they
            exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete an incompletely received file. Attempts to resume the download if such a file
            exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        subfolder (`str`, *optional*, defaults to `""`):
            In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
            specify the folder name here.

    <Tip>

    Passing `use_auth_token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo) or `None` if the
        file does not exist.

    Examples:

    ```python
    # Download a tokenizer configuration from huggingface.co and cache.
    tokenizer_config = get_file_from_repo("bert-base-uncased", "tokenizer_config.json")
    # This model does not have a tokenizer config so the result will be None.
    tokenizer_config = get_file_from_repo("xlm-roberta-base", "tokenizer_config.json")
    ```"""
    return cached_file(
        path_or_repo_id=path_or_repo,
        filename=filename,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        use_auth_token=use_auth_token,
        revision=revision,
        local_files_only=local_files_only,
        subfolder=subfolder,
        _raise_exceptions_for_missing_entries=False,
        _raise_exceptions_for_connection_errors=False,
    )

def download_url(url, proxies=None):
    """
    Downloads a given url into a temporary file. This function is not safe to use in multiple processes. Its only use
    is for deprecated behavior that allows downloading config/models with a single url instead of using the Hub.

    Args:
        url (`str`): The url of the file to download.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.

    Returns:
        `str`: The location of the temporary file where the url was downloaded.
    """
    warnings.warn(
        f"Using `from_pretrained` with the url of a file (here {url}) is deprecated and won't be possible anymore in"
        " v5 of Transformers. You should host your file on the Hub (hf.co) instead and use the repository ID. Note"
        " that this is not compatible with the caching system (your file will be downloaded at each execution) or"
        " multiple processes (each process will download the file in a different temporary file)."
    )
    tmp_file = tempfile.mktemp()
    with open(tmp_file, "wb") as f:
        http_get(url, f, proxies=proxies)
    return tmp_file

def has_file(
    path_or_repo: Union[str, os.PathLike],
    filename: str,
    revision: Optional[str] = None,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
):
    """
    Checks if a repo contains a given file without downloading it. Works for remote repos and local folders.

    <Tip warning={false}>

    This function will raise an error if the repository `path_or_repo` is not valid or if `revision` does not exist for
    this repo, but will return False for regular connection errors.

    </Tip>
    """
    if os.path.isdir(path_or_repo):
        return os.path.isfile(os.path.join(path_or_repo, filename))

    url = hf_hub_url(path_or_repo, filename=filename, revision=revision)

    headers = {"user-agent": http_user_agent()}
    if isinstance(use_auth_token, str):
        headers["authorization"] = f"Bearer {use_auth_token}"
    elif use_auth_token:
        token = HfFolder.get_token()
        if token is None:
            raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
        headers["authorization"] = f"Bearer {token}"

    r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=10)
    try:
        hf_raise_for_status(r)
        return True
    except RepositoryNotFoundError as e:
        logger.error(e)
        raise EnvironmentError(f"{path_or_repo} is not a local folder or a valid repository name on 'https://hf.co'.")
    except RevisionNotFoundError as e:
        logger.error(e)
        raise EnvironmentError(
            f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
            f"model name. Check the model page at 'https://huggingface.co/{path_or_repo}' for available revisions."
        )
    except requests.HTTPError:
        # We return false for EntryNotFoundError (logical) as well as any connection error.
        return False

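# Illustrative sketch (added, not upstream code): `has_file` answers "does this repo
# ship this file?" without downloading anything, which helps distinguish checkpoint
# formats. The repo id and filenames are assumptions for the example.
def _example_pick_checkpoint_format():
    repo = "bert-base-uncased"
    if has_file(repo, "pytorch_model.bin"):
        print("PyTorch weights available")
    elif has_file(repo, "tf_model.h5"):
        print("only TensorFlow weights available")
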
class PushToHubMixin:
    """
    A Mixin containing the functionality to push a model or tokenizer to the hub.
    """

    def _create_repo(
        self,
        repo_id: str,
        private: Optional[bool] = None,
        use_auth_token: Optional[Union[bool, str]] = None,
        repo_url: Optional[str] = None,
        organization: Optional[str] = None,
    ):
        """
        Creates the repo if needed, cleans up `repo_id` from the deprecated kwargs `repo_url` and `organization`, and
        retrieves the token.
        """
        if repo_url is not None:
            warnings.warn(
                "The `repo_url` argument is deprecated and will be removed in v5 of Transformers. Use `repo_id` "
                "instead."
            )
            repo_id = repo_url.replace(f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/", "")
        if organization is not None:
            warnings.warn(
                "The `organization` argument is deprecated and will be removed in v5 of Transformers. Set your "
                "organization directly in the `repo_id` passed instead (`repo_id={organization}/{model_id}`)."
            )
            if not repo_id.startswith(organization):
                if "/" in repo_id:
                    repo_id = repo_id.split("/")[-1]
                repo_id = f"{organization}/{repo_id}"

        token = HfFolder.get_token() if use_auth_token is True else use_auth_token
        url = create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True)

        # If the namespace is not there, add it or `upload_file` will complain
        if "/" not in repo_id and url != f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{repo_id}":
            repo_id = get_full_repo_name(repo_id, token=token)
        return repo_id, token
    def _get_files_timestamps(self, working_dir: Union[str, os.PathLike]):
        """
        Returns the list of files with their last modification timestamp.
        """
        return {f: os.path.getmtime(os.path.join(working_dir, f)) for f in os.listdir(working_dir)}
    def _upload_modified_files(
        self,
        working_dir: Union[str, os.PathLike],
        repo_id: str,
        files_timestamps: Dict[str, float],
        commit_message: Optional[str] = None,
        token: Optional[str] = None,
        create_pr: bool = False,
    ):
        """
        Uploads all modified files in `working_dir` to `repo_id`, based on `files_timestamps`.
        """
        if commit_message is None:
            if "Model" in self.__class__.__name__:
                commit_message = "Upload model"
            elif "Config" in self.__class__.__name__:
                commit_message = "Upload config"
            elif "Tokenizer" in self.__class__.__name__:
                commit_message = "Upload tokenizer"
            elif "FeatureExtractor" in self.__class__.__name__:
                commit_message = "Upload feature extractor"
            elif "Processor" in self.__class__.__name__:
                commit_message = "Upload processor"
            else:
                commit_message = f"Upload {self.__class__.__name__}"
        modified_files = [
            f
            for f in os.listdir(working_dir)
            if f not in files_timestamps or os.path.getmtime(os.path.join(working_dir, f)) > files_timestamps[f]
        ]
        operations = []
        for file in modified_files:
            operations.append(CommitOperationAdd(path_or_fileobj=os.path.join(working_dir, file), path_in_repo=file))
        logger.info(f"Uploading the following files to {repo_id}: {','.join(modified_files)}")
        return create_commit(
            repo_id=repo_id, operations=operations, commit_message=commit_message, token=token, create_pr=create_pr
        )
    def push_to_hub(
        self,
        repo_id: str,
        use_temp_dir: Optional[bool] = None,
        commit_message: Optional[str] = None,
        private: Optional[bool] = None,
        use_auth_token: Optional[Union[bool, str]] = None,
        max_shard_size: Optional[Union[int, str]] = "10GB",
        create_pr: bool = False,
        **deprecated_kwargs
    ) -> str:
        """
        Upload the {object_files} to the 🤗 Model Hub while synchronizing a local clone of the repo in
        `repo_path_or_name`.

        Parameters:
            repo_id (`str`):
                The name of the repository you want to push your {object} to. It should contain your organization name
                when pushing to a given organization.
            use_temp_dir (`bool`, *optional*):
                Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub.
                Will default to `True` if there is no directory named like `repo_id`, `False` otherwise.
            commit_message (`str`, *optional*):
                Message to commit while pushing. Will default to `"Upload {object}"`.
            private (`bool`, *optional*):
                Whether or not the repository created should be private (requires a paying subscription).
            use_auth_token (`bool` or `str`, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `huggingface-cli login` (stored in `~/.huggingface`). Will default to `True` if `repo_url`
                is not specified.
            max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
                Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoint shards
                will then each be of a size lower than this. If expressed as a string, needs to be digits followed
                by a unit (like `"5MB"`).
            create_pr (`bool`, *optional*, defaults to `False`):
                Whether or not to create a PR with the uploaded files or directly commit.

        Examples:

        ```python
        from transformers import {object_class}

        {object} = {object_class}.from_pretrained("bert-base-cased")

        # Push the {object} to your namespace with the name "my-finetuned-bert".
        {object}.push_to_hub("my-finetuned-bert")

        # Push the {object} to an organization with the name "my-finetuned-bert".
        {object}.push_to_hub("huggingface/my-finetuned-bert")
        ```
        """
        if "repo_path_or_name" in deprecated_kwargs:
            warnings.warn(
                "The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
                "`repo_id` instead."
            )
            repo_id = deprecated_kwargs.pop("repo_path_or_name")
        # Deprecation warning will be sent after for repo_url and organization
        repo_url = deprecated_kwargs.pop("repo_url", None)
        organization = deprecated_kwargs.pop("organization", None)

        if os.path.isdir(repo_id):
            working_dir = repo_id
            repo_id = repo_id.split(os.path.sep)[-1]
        else:
            working_dir = repo_id.split("/")[-1]

        repo_id, token = self._create_repo(
            repo_id, private=private, use_auth_token=use_auth_token, repo_url=repo_url, organization=organization
        )

        if use_temp_dir is None:
            use_temp_dir = not os.path.isdir(working_dir)

        with working_or_temp_dir(working_dir=working_dir, use_temp_dir=use_temp_dir) as work_dir:
            files_timestamps = self._get_files_timestamps(work_dir)

            # Save all files.
            self.save_pretrained(work_dir, max_shard_size=max_shard_size)

            return self._upload_modified_files(
                work_dir, repo_id, files_timestamps, commit_message=commit_message, token=token, create_pr=create_pr
            )

def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"

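# Illustrative sketch (added, not upstream code): without an organization the repo
# name is prefixed with the logged-in username; with one, the organization namespace
# is used. The values below are assumptions for the example.
def _example_full_repo_name():
    print(get_full_repo_name("my-finetuned-bert", organization="my-org"))  # "my-org/my-finetuned-bert"
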
def send_example_telemetry(example_name, *example_args, framework="pytorch"):
    """
    Sends telemetry that helps track example usage.

    Args:
        example_name (`str`): The name of the example.
        *example_args (dataclasses or `argparse.ArgumentParser`): The arguments to the script. This function will only
            try to extract the model and dataset name from those. Nothing else is tracked.
        framework (`str`, *optional*, defaults to `"pytorch"`): The framework for the example.
    """
    if is_offline_mode():
        return

    data = {"example": example_name, "framework": framework}
    for args in example_args:
        args_as_dict = {k: v for k, v in args.__dict__.items() if not k.startswith("_") and v is not None}
        if "model_name_or_path" in args_as_dict:
            model_name = args_as_dict["model_name_or_path"]
            # Filter out local paths
            if not os.path.isdir(model_name):
                data["model_name"] = args_as_dict["model_name_or_path"]
        if "dataset_name" in args_as_dict:
            data["dataset_name"] = args_as_dict["dataset_name"]
        elif "task_name" in args_as_dict:
            # Extract script name from the example_name
            script_name = example_name.replace("tf_", "").replace("flax_", "").replace("run_", "")
            script_name = script_name.replace("_no_trainer", "")
            data["dataset_name"] = f"{script_name}-{args_as_dict['task_name']}"

    headers = {"user-agent": http_user_agent(data)}
    try:
        r = requests.head(HUGGINGFACE_CO_EXAMPLES_TELEMETRY, headers=headers)
        r.raise_for_status()
    except Exception:
        # We don't want to error in case of connection errors of any kind.
        pass

def convert_file_size_to_int(size: Union[int, str]):
    """
    Converts a size expressed as a string with digits and a unit (like `"5MB"`) to an integer (in bytes).

    Args:
        size (`int` or `str`): The size to convert. Will be directly returned if an `int`.

    Example:
    ```py
    >>> convert_file_size_to_int("1MiB")
    1048576
    ```
    """
    if isinstance(size, int):
        return size
    if size.upper().endswith("GIB"):
        return int(size[:-3]) * (2**30)
    if size.upper().endswith("MIB"):
        return int(size[:-3]) * (2**20)
    if size.upper().endswith("KIB"):
        return int(size[:-3]) * (2**10)
    if size.upper().endswith("GB"):
        int_size = int(size[:-2]) * (10**9)
        return int_size // 8 if size.endswith("b") else int_size
    if size.upper().endswith("MB"):
        int_size = int(size[:-2]) * (10**6)
        return int_size // 8 if size.endswith("b") else int_size
    if size.upper().endswith("KB"):
        int_size = int(size[:-2]) * (10**3)
        return int_size // 8 if size.endswith("b") else int_size
    raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")

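# Illustrative sketch (added, not upstream code): a lowercase trailing "b" denotes
# bits, so the value is divided by 8 to obtain bytes; "iB" units are binary.
def _example_sizes():
    assert convert_file_size_to_int("5MB") == 5_000_000  # megabytes
    assert convert_file_size_to_int("5Mb") == 625_000  # megabits -> bytes
    assert convert_file_size_to_int("1GiB") == 2**30  # binary unit
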
def get_checkpoint_shard_files(
    pretrained_model_name_or_path,
    index_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    local_files_only=False,
    use_auth_token=None,
    user_agent=None,
    revision=None,
    subfolder="",
    _commit_hash=None,
):
    """
    For a given model:

    - downloads and caches all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on
      the Hub
    - returns the list of paths to all the shards, as well as some metadata.

    For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
    index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
    """
    if not os.path.isfile(index_filename):
        raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")

    with open(index_filename, "r") as f:
        index = json.loads(f.read())

    shard_filenames = sorted(set(index["weight_map"].values()))
    sharded_metadata = index["metadata"]
    sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())

    # First, let's deal with local folder.
    if os.path.isdir(pretrained_model_name_or_path):
        shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames]
        return shard_filenames, sharded_metadata

    # At this stage pretrained_model_name_or_path is a model identifier on the Hub
    cached_filenames = []
    for shard_filename in shard_filenames:
        try:
            # Load from URL
            cached_filename = cached_file(
                pretrained_model_name_or_path,
                shard_filename,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
                revision=revision,
                subfolder=subfolder,
                _commit_hash=_commit_hash,
            )
        # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
        # we don't have to catch them here.
        except EntryNotFoundError:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is "
                "required according to the checkpoint index."
            )
        except HTTPError:
            raise EnvironmentError(
                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {shard_filename}. You should try"
                " again after checking your internet connection."
            )

        cached_filenames.append(cached_filename)

    return cached_filenames, sharded_metadata

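# Hedged sketch (added for exposition): the shape of a sharded checkpoint index that
# `get_checkpoint_shard_files` consumes. The filenames, key names and total size are
# made-up example values; only the top-level "metadata"/"weight_map" structure is
# what the function relies on.
_EXAMPLE_SHARD_INDEX = {
    "metadata": {"total_size": 28000000000},
    "weight_map": {
        "embeddings.weight": "pytorch_model-00001-of-00002.bin",
        "lm_head.weight": "pytorch_model-00002-of-00002.bin",
    },
}
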
# Everything below is for conversion between the old cache format and the new one.


def get_all_cached_files(cache_dir=None):
    """
    Returns a list of all cached files with their appropriate metadata.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    else:
        cache_dir = str(cache_dir)
    if not os.path.isdir(cache_dir):
        return []

    cached_files = []
    for file in os.listdir(cache_dir):
        meta_path = os.path.join(cache_dir, f"{file}.json")
        if not os.path.isfile(meta_path):
            continue

        with open(meta_path, encoding="utf-8") as meta_file:
            metadata = json.load(meta_file)
        url = metadata["url"]
        etag = metadata["etag"].replace('"', "")
        cached_files.append({"file": file, "url": url, "etag": etag})

    return cached_files

def extract_info_from_url(url):
    """
    Extracts repo_name, revision and filename from a URL.
    """
    search = re.search(r"^https://huggingface\.co/(.*)/resolve/([^/]*)/(.*)$", url)
    if search is None:
        return None
    repo, revision, filename = search.groups()
    cache_repo = "--".join(["models"] + repo.split("/"))
    return {"repo": cache_repo, "revision": revision, "filename": filename}

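# Illustrative sketch (added, not upstream code): old-style cache entries record the
# full resolve URL, from which the repo, revision and filename are recovered. The
# URL below is a made-up example.
def _example_extract_info():
    info = extract_info_from_url("https://huggingface.co/bert-base-uncased/resolve/main/config.json")
    print(info)  # {'repo': 'models--bert-base-uncased', 'revision': 'main', 'filename': 'config.json'}
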
def clean_files_for(file):
    """
    Removes `file`, `file.json` and `file.lock` if they exist.
    """
    for f in [file, f"{file}.json", f"{file}.lock"]:
        if os.path.isfile(f):
            os.remove(f)

def move_to_new_cache(file, repo, filename, revision, etag, commit_hash):
    """
    Moves `file` to `repo` following the new huggingface hub cache organization.
    """
    os.makedirs(repo, exist_ok=True)

    # refs
    os.makedirs(os.path.join(repo, "refs"), exist_ok=True)
    if revision != commit_hash:
        ref_path = os.path.join(repo, "refs", revision)
        with open(ref_path, "w") as f:
            f.write(commit_hash)

    # blobs
    os.makedirs(os.path.join(repo, "blobs"), exist_ok=True)
    blob_path = os.path.join(repo, "blobs", etag)
    shutil.move(file, blob_path)

    # snapshots
    os.makedirs(os.path.join(repo, "snapshots"), exist_ok=True)
    os.makedirs(os.path.join(repo, "snapshots", commit_hash), exist_ok=True)
    pointer_path = os.path.join(repo, "snapshots", commit_hash, filename)
    huggingface_hub.file_download._create_relative_symlink(blob_path, pointer_path)
    clean_files_for(file)

def move_cache(cache_dir=None, new_cache_dir=None, token=None):
    if new_cache_dir is None:
        new_cache_dir = TRANSFORMERS_CACHE
    if cache_dir is None:
        # Migrate from old cache in .cache/huggingface/hub
        old_cache = Path(TRANSFORMERS_CACHE).parent / "transformers"
        if os.path.isdir(str(old_cache)):
            cache_dir = str(old_cache)
        else:
            cache_dir = new_cache_dir
    if token is None:
        token = HfFolder.get_token()
    cached_files = get_all_cached_files(cache_dir=cache_dir)
    # print(f"Moving {len(cached_files)} files to the new cache system")

    hub_metadata = {}
    for file_info in cached_files:
        url = file_info.pop("url")
        if url not in hub_metadata:
            try:
                hub_metadata[url] = get_hf_file_metadata(url, use_auth_token=token)
            except requests.HTTPError:
                continue

        etag, commit_hash = hub_metadata[url].etag, hub_metadata[url].commit_hash
        if etag is None or commit_hash is None:
            continue

        if file_info["etag"] != etag:
            # The cached file is not up to date: just delete it, since a new version will be downloaded anyway.
            clean_files_for(os.path.join(cache_dir, file_info["file"]))
            continue

        url_info = extract_info_from_url(url)
        if url_info is None:
            # Not a file from huggingface.co
            continue

        repo = os.path.join(new_cache_dir, url_info["repo"])
        move_to_new_cache(
            file=os.path.join(cache_dir, file_info["file"]),
            repo=repo,
            filename=url_info["filename"],
            revision=url_info["revision"],
            etag=etag,
            commit_hash=commit_hash,
        )

cache_version_file = os.path.join(TRANSFORMERS_CACHE, "version.txt")
if not os.path.isfile(cache_version_file):
    cache_version = 0
else:
    with open(cache_version_file) as f:
        cache_version = int(f.read())

cache_is_not_empty = os.path.isdir(TRANSFORMERS_CACHE) and len(os.listdir(TRANSFORMERS_CACHE)) > 0

if cache_version < 1 and cache_is_not_empty:
    if is_offline_mode():
        logger.warning(
            "You are offline and the cache for model files in Transformers v4.22.0 has been updated while your local "
            "cache seems to be the one of a previous version. It is very likely that all your calls to any "
            "`from_pretrained()` method will fail. Remove the offline mode and enable internet connection to have "
            "your cache be updated automatically, then you can go back to offline mode."
        )
    try:
        if TRANSFORMERS_CACHE != default_cache_path:
            # Users set some env variable to customize cache storage
            move_cache(TRANSFORMERS_CACHE, TRANSFORMERS_CACHE)
        else:
            move_cache()
    except Exception as e:
        trace = "\n".join(traceback.format_tb(e.__traceback__))
        logger.error(
            f"There was a problem when trying to move your cache:\n\n{trace}\n{e.__class__.__name__}: {e}\n\nPlease "
            "file an issue at https://github.com/huggingface/transformers/issues/new/choose and copy paste this whole "
            "message and we will do our best to help."
        )

    try:
        os.makedirs(TRANSFORMERS_CACHE, exist_ok=True)
        with open(cache_version_file, "w") as f:
            f.write("1")
    except Exception:
        logger.warning(
            f"There was a problem when trying to write in your cache folder ({TRANSFORMERS_CACHE}). You should set "
            "the environment variable TRANSFORMERS_CACHE to a writable directory."
        )