import copy
import fnmatch
import json
import os
import shutil
import sys
import tarfile
import tempfile
from functools import partial
from hashlib import sha256
from pathlib import Path
from typing import Dict, List, Optional, Union, Tuple, BinaryIO
from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile

import importlib_metadata
import requests
import torch
import torch.nn as nn
from filelock import FileLock
from torch import Tensor
from tqdm.auto import tqdm
# Version string reported in the HTTP user agent built by http_user_agent.
__version__ = "4.0.0"
# Installed torch version, read from package metadata when this module is imported.
_torch_version = importlib_metadata.version("torch")
# Cache root: $HF_HOME if set, otherwise "<$XDG_CACHE_HOME or ~/.cache>/huggingface",
# with a leading "~" expanded to the user's home directory.
hf_cache_home = os.path.expanduser(os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")))
default_cache_path = os.path.join(hf_cache_home, "transformers")
# Cache directory resolution: each environment variable, when set, overrides the
# value computed on the line before it, so TRANSFORMERS_CACHE has the final say.
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
# Named mirror endpoints. hf_bucket_url looks a mirror name up here and falls
# back to using the given value itself as the endpoint when it is not a key.
PRESET_MIRROR_DICT = {
"tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
"bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
}
# URL template filled in by hf_bucket_url (model_id, revision, filename).
HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
# NOTE(review): presumably the default weights/config file names; neither is
# referenced in the visible part of this file -- confirm against callers.
WEIGHTS_NAME = "pytorch_model.bin"
CONFIG_NAME = "config.json"
def is_torch_available():
    """Report whether torch can be used; this module hard-codes the answer to True."""
    return True
def is_tf_available():
    """Report whether TensorFlow can be used; this module hard-codes the answer to False."""
    return False
def is_remote_url(url_or_filename):
    """Return True when ``url_or_filename`` parses with an ``http`` or ``https`` scheme."""
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
    """Stream the body of ``url`` into ``temp_file`` while displaying a progress bar.

    Args:
        url: Address to download with a streaming GET request.
        temp_file: Writable binary file object that receives the downloaded chunks.
        proxies: Passed through to ``requests.get``.
        resume_size: Number of bytes already present locally. When positive, a
            ``Range: bytes=<resume_size>-`` header is sent and the progress bar
            starts from that offset.
        headers: Optional request headers. The mapping is copied, never mutated.

    Raises:
        requests.HTTPError: If the server answers with an error status
            (via ``raise_for_status``).
    """
    # Copy so the caller's dict is not mutated, and fall back to a fresh dict so
    # that adding the Range header below also works when no headers were given
    # (previously this raised TypeError on None).
    headers = copy.deepcopy(headers) if headers is not None else {}
    if resume_size > 0:
        headers["Range"] = "bytes=%d-" % (resume_size,)

    r = requests.get(url, stream=True, proxies=proxies, headers=headers)
    try:
        r.raise_for_status()
        content_length = r.headers.get("Content-Length")
        # Content-Length only covers what is still to be sent, so add the part
        # that is already on disk; without the header the total is unknown.
        total = resume_size + int(content_length) if content_length is not None else None
        progress = tqdm(
            unit="B",
            unit_scale=True,
            total=total,
            initial=resume_size,
            desc="Downloading",
            disable=False,
        )
        try:
            for chunk in r.iter_content(chunk_size=1024):
                if chunk:  # skip empty chunks
                    progress.update(len(chunk))
                    temp_file.write(chunk)
        finally:
            # Close the bar even when the transfer is interrupted.
            progress.close()
    finally:
        # A streamed response keeps its connection checked out until closed.
        r.close()
def url_to_filename(url: str, etag: Optional[str] = None) -> str:
    """Derive a cache file name from ``url`` and, optionally, its ``etag``.

    The name is the SHA-256 hex digest of the UTF-8 encoded URL. A non-empty
    ``etag`` contributes a second digest joined with a period, and a URL ending
    in ``.h5`` keeps that extension on the resulting name.
    """
    digests = [sha256(url.encode("utf-8")).hexdigest()]
    if etag:
        digests.append(sha256(etag.encode("utf-8")).hexdigest())
    name = ".".join(digests)
    return name + ".h5" if url.endswith(".h5") else name
def hf_bucket_url(
    model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None
) -> str:
    """Build the download URL of ``filename`` for ``model_id``.

    Args:
        model_id: Model identifier. With a mirror, an identifier without ``/``
            is joined to the file name with ``-`` instead of ``/``.
        filename: Name of the file; prefixed with ``subfolder/`` when given.
        subfolder: Optional folder placed in front of ``filename``.
        revision: Revision inserted into the URL; ``None`` means ``"main"``.
            Not used when a mirror is given.
        mirror: Key of ``PRESET_MIRROR_DICT`` or an endpoint to use as is.

    Returns:
        The resolved URL string.
    """
    if subfolder is not None:
        filename = f"{subfolder}/{filename}"

    if not mirror:
        resolved_revision = "main" if revision is None else revision
        return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=resolved_revision, filename=filename)

    # A known mirror name maps to its endpoint; anything else is the endpoint.
    endpoint = PRESET_MIRROR_DICT.get(mirror, mirror)
    if "/" in model_id:
        return f"{endpoint}/{model_id}/{filename}"
    # Identifiers without a namespace use the dash-joined layout.
    return f"{endpoint}/{model_id}-{filename}"
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    """Compose the user-agent string sent with download requests.

    The string starts with the library and Python versions, followed by the
    torch / tensorflow versions when the matching availability check passes.
    A dict ``user_agent`` is appended as ``key/value`` pairs separated by
    ``"; "``; a str ``user_agent`` is appended verbatim.
    """
    python_version = sys.version.split()[0]
    pieces = ["transformers/{}; python/{}".format(__version__, python_version)]
    if is_torch_available():
        pieces.append(f"; torch/{_torch_version}")
    if is_tf_available():
        pieces.append(f"; tensorflow/{_tf_version}")

    if isinstance(user_agent, dict):
        extras = ("{}/{}".format(key, value) for key, value in user_agent.items())
        pieces.append("; " + "; ".join(extras))
    elif isinstance(user_agent, str):
        pieces.append("; " + user_agent)
    return "".join(pieces)
def get_from_cache(
    url: str,
    cache_dir=None,
    force_download=False,
    proxies=None,
    etag_timeout=10,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """Return the local cache path for ``url``, downloading the file if needed.

    Args:
        url: Remote file to fetch.
        cache_dir: Cache directory (str or ``Path``); defaults to
            ``TRANSFORMERS_CACHE``. Created if missing.
        force_download: Download even when a cached copy for the current ETag exists.
        proxies: Passed through to ``requests``.
        etag_timeout: Timeout in seconds for the HEAD request that fetches the ETag.
        resume_download: Append to ``<cache_path>.incomplete`` and request only
            the missing byte range instead of starting from scratch.
        user_agent: Extra user-agent information, see ``http_user_agent``.
        use_auth_token: A token string, or a truthy value to use the stored token.
        local_files_only: Skip every network request and only look in the cache.

    Returns:
        Path of the cached file.

    Raises:
        EnvironmentError: ``use_auth_token`` is truthy but no stored token exists.
        OSError: The server response carries no ETag.
        FileNotFoundError: ``local_files_only`` is set and nothing is cached.
        ValueError: The ETag could not be fetched (connection error or timeout)
            and nothing is cached.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    os.makedirs(cache_dir, exist_ok=True)

    headers = {"user-agent": http_user_agent(user_agent)}
    if isinstance(use_auth_token, str):
        headers["authorization"] = "Bearer {}".format(use_auth_token)
    elif use_auth_token:
        # NOTE(review): HfFolder is not imported in the visible part of this
        # file; this branch needs it to be provided elsewhere -- confirm.
        token = HfFolder.get_token()
        if token is None:
            raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
        headers["authorization"] = "Bearer {}".format(token)

    url_to_download = url
    etag = None
    if not local_files_only:
        try:
            r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
            r.raise_for_status()
            etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
            if etag is None:
                raise OSError(
                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
                )
            # Redirects are not followed by the HEAD request; download straight
            # from the redirect target instead.
            if 300 <= r.status_code <= 399:
                url_to_download = r.headers["Location"]
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            # Deliberate best effort: etag stays None and the cache is used below.
            pass

    filename = url_to_filename(url, etag)
    cache_path = os.path.join(cache_dir, filename)

    # No ETag means the server was not reached (or not asked): fall back to
    # whatever was downloaded for this URL before.
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path
        matching_files = [
            file
            for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
            if not file.endswith(".json") and not file.endswith(".lock")
        ]
        if matching_files:
            return os.path.join(cache_dir, matching_files[-1])
        if local_files_only:
            raise FileNotFoundError(
                "Cannot find the requested files in the cached path and outgoing traffic has been"
                " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                " to False."
            )
        raise ValueError(
            "Connection error, and we cannot find the requested files in the cached path."
            " Please try again or make sure your Internet connection is on."
        )

    # From here on etag is known.
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Serialize concurrent downloads of the same file.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):
        # Another process may have finished the download while we were waiting.
        if os.path.exists(cache_path) and not force_download:
            return cache_path

        if resume_download:
            incomplete_path = cache_path + ".incomplete"
            # A file opened in append mode is its own context manager and its
            # ``name`` is ``incomplete_path``, so no custom wrapper is needed.
            temp_file_manager = partial(open, incomplete_path, "ab")
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
            resume_size = 0

        temp_path = None
        try:
            with temp_file_manager() as temp_file:
                temp_path = temp_file.name
                http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)
        except BaseException:
            # A partial file is only useful when resuming; otherwise remove it
            # so failed downloads do not pile up in the cache directory.
            if not resume_download and temp_path is not None:
                try:
                    os.remove(temp_path)
                except OSError:
                    pass  # best-effort cleanup; the original error is re-raised
            raise

        # Move the finished download into place only once it is complete.
        os.replace(temp_path, cache_path)

        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
def cached_path(
    url_or_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    extract_compressed_file=False,
    force_extract=False,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """Resolve a URL or local path to a local path, optionally extracting archives.

    Args:
        url_or_filename: ``http``/``https`` URL or local path (str or ``Path``).
        cache_dir: Cache directory for downloads; defaults to ``TRANSFORMERS_CACHE``.
        force_download: Forwarded to ``get_from_cache``.
        proxies: Forwarded to ``get_from_cache``.
        resume_download: Forwarded to ``get_from_cache``.
        user_agent: Forwarded to ``get_from_cache``.
        extract_compressed_file: When true and the resolved file is a zip or tar
            archive, extract it next to the file and return the extraction directory.
        force_extract: Extract again even if a non-empty extraction directory exists.
        use_auth_token: Forwarded to ``get_from_cache``.
        local_files_only: Forwarded to ``get_from_cache``.

    Returns:
        The local file path, or the extraction directory when an archive was extracted.

    Raises:
        EnvironmentError: The local file does not exist, or the archive format
            could not be identified.
        ValueError: The input is neither an http(s) URL nor a local path.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if is_remote_url(url_or_filename):
        output_path = get_from_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            user_agent=user_agent,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
        )
    elif os.path.exists(url_or_filename):
        output_path = url_or_filename
    elif urlparse(url_or_filename).scheme == "":
        # No scheme: it was meant as a local path that does not exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))

    if not extract_compressed_file:
        return output_path
    if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
        return output_path

    # Extract next to the archive; dots are replaced so the directory name has
    # no extension: "model.zip" -> "model-zip-extracted".
    output_dir, output_file = os.path.split(output_path)
    output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
    output_path_extracted = os.path.join(output_dir, output_extract_dir_name)

    if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
        return output_path_extracted

    # Serialize concurrent extractions of the same archive.
    lock_path = output_path + ".lock"
    with FileLock(lock_path):
        shutil.rmtree(output_path_extracted, ignore_errors=True)
        os.makedirs(output_path_extracted)
        # SECURITY NOTE(review): extractall() trusts member paths stored in the
        # archive; archives from untrusted sources can write outside the target
        # directory. Consider validating members or using tarfile extraction
        # filters where the supported Python versions allow it.
        if is_zipfile(output_path):
            with ZipFile(output_path, "r") as zip_file:
                zip_file.extractall(output_path_extracted)
        elif tarfile.is_tarfile(output_path):
            # The context manager closes the archive even if extraction fails.
            with tarfile.open(output_path) as tar_file:
                tar_file.extractall(output_path_extracted)
        else:
            raise EnvironmentError("Archive format of {} could not be identified".format(output_path))

    return output_path_extracted
def get_parameter_dtype(parameter: Union[nn.Module]):
    """Return the dtype of the first parameter of ``parameter``.

    When the module yields no parameters, the dtype of the first tensor found
    among the instance attributes of the module (or its submodules) is
    returned instead. ``StopIteration`` propagates if no tensor is found.
    """
    first_param = next(parameter.parameters(), None)
    if first_param is not None:
        return first_param.dtype

    def _tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
        # Plain tensor attributes live in the instance dict, not in parameters().
        return [(name, value) for name, value in vars(module).items() if torch.is_tensor(value)]

    _, first_tensor = next(parameter._named_members(get_members_fn=_tensor_attributes))
    return first_tensor.dtype
def get_extended_attention_mask(attention_mask: Tensor, dtype) -> Tensor:
    """Turn a 2-D mask into a 4-D additive mask of the given ``dtype``.

    Two singleton dimensions are inserted after the first axis, the result is
    cast to ``dtype``, and values are mapped through ``(1.0 - x) * -10000.0``,
    so entries equal to 1 become 0 and entries equal to 0 become -10000.
    """
    assert attention_mask.dim() == 2
    # (d0, d1) -> (d0, 1, 1, d1); presumably (batch, seq_len) -- confirm with callers.
    broadcastable = attention_mask.unsqueeze(1).unsqueeze(2).to(dtype=dtype)
    return (1.0 - broadcastable) * -10000.0