GitHub Repository: yiming-wange/cs224n-2023-solution
Path: blob/main/minBERT/utils.py
from typing import Dict, List, Optional, Union, Tuple, BinaryIO
# Note: io, shutil, tarfile, zipfile, and contextmanager are used below but were
# missing from the original imports.
import io
import os
import sys
import json
import shutil
import tarfile
import tempfile
import copy
import fnmatch
from contextlib import contextmanager
from zipfile import ZipFile, is_zipfile
from tqdm.auto import tqdm
from functools import partial
from urllib.parse import urlparse
from pathlib import Path
import requests
from hashlib import sha256
from filelock import FileLock
# HfFolder supplies the locally stored auth token; get_from_cache below uses it,
# but the original file omits the import (assumed to come from huggingface_hub).
from huggingface_hub import HfFolder
import importlib_metadata
import torch
import torch.nn as nn
from torch import Tensor

__version__ = "4.0.0"
_torch_version = importlib_metadata.version("torch")

hf_cache_home = os.path.expanduser(
    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
default_cache_path = os.path.join(hf_cache_home, "transformers")
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)

PRESET_MIRROR_DICT = {
    "tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
    "bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
}
HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
WEIGHTS_NAME = "pytorch_model.bin"
CONFIG_NAME = "config.json"


def is_torch_available():
    # Torch is a hard dependency of this standalone module.
    return True


def is_tf_available():
    # TensorFlow support is stripped out of this standalone module.
    return False


def is_remote_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")
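

# A hedged usage sketch (not part of the original file; the helper name is ours):
# is_remote_url only checks the URL scheme, so bare local paths fall through to
# the filesystem branches of cached_path below.
def _example_is_remote_url():
    assert is_remote_url("https://huggingface.co/bert-base-uncased/resolve/main/config.json")
    assert not is_remote_url("/local/path/to/config.json")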


def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
    # Copy the caller's headers so the Range header added below does not leak back
    # (the original deepcopy crashed on the default headers=None).
    headers = copy.deepcopy(headers) if headers is not None else {}
    if resume_size > 0:
        # Ask the server for only the bytes we have not downloaded yet.
        headers["Range"] = "bytes=%d-" % (resume_size,)
    r = requests.get(url, stream=True, proxies=proxies, headers=headers)
    r.raise_for_status()
    content_length = r.headers.get("Content-Length")
    total = resume_size + int(content_length) if content_length is not None else None
    progress = tqdm(
        unit="B",
        unit_scale=True,
        total=total,
        initial=resume_size,
        desc="Downloading",
        disable=False,
    )
    for chunk in r.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()
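

# A hedged usage sketch (not part of the original file; the helper name is ours):
# stream any reachable URL into a temporary file. Kept behind a function so that
# importing this module never touches the network.
def _example_http_get(url: str):
    with tempfile.NamedTemporaryFile(mode="wb", delete=False) as tmp:
        http_get(url, tmp)
    return tmp.name  # path to the downloaded bytes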


def url_to_filename(url: str, etag: Optional[str] = None) -> str:
    # The cache filename is the sha256 of the URL, optionally extended with the
    # sha256 of the ETag so that distinct revisions of the same file coexist.
    url_bytes = url.encode("utf-8")
    filename = sha256(url_bytes).hexdigest()

    if etag:
        etag_bytes = etag.encode("utf-8")
        filename += "." + sha256(etag_bytes).hexdigest()

    # Keep the .h5 suffix so TensorFlow weight files remain recognizable on disk.
    if url.endswith(".h5"):
        filename += ".h5"

    return filename
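

# A hedged sketch (not part of the original file; the helper name and the etag
# value are ours): the cache name is sha256(url), plus sha256(etag) when known.
def _example_url_to_filename():
    url = "https://huggingface.co/bert-base-uncased/resolve/main/config.json"
    name = url_to_filename(url, etag='"abc123"')
    url_hash, etag_hash = name.split(".")
    assert url_hash == sha256(url.encode("utf-8")).hexdigest()
    assert etag_hash == sha256('"abc123"'.encode("utf-8")).hexdigest()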


def hf_bucket_url(
    model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None
) -> str:
    if subfolder is not None:
        filename = f"{subfolder}/{filename}"

    if mirror:
        # Either a preset mirror name ("tuna", "bfsu") or a full endpoint URL.
        endpoint = PRESET_MIRROR_DICT.get(mirror, mirror)
        legacy_format = "/" not in model_id
        if legacy_format:
            return f"{endpoint}/{model_id}-{filename}"
        else:
            return f"{endpoint}/{model_id}/{filename}"

    if revision is None:
        revision = "main"
    return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=revision, filename=filename)
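

# A hedged sketch (not part of the original file; the helper name is ours): the
# two URL layouts hf_bucket_url produces, for the default endpoint and for the
# "tuna" mirror's legacy flat layout (no "/" in the model id).
def _example_hf_bucket_url():
    assert (
        hf_bucket_url("bert-base-uncased", CONFIG_NAME)
        == "https://huggingface.co/bert-base-uncased/resolve/main/config.json"
    )
    assert (
        hf_bucket_url("bert-base-uncased", WEIGHTS_NAME, mirror="tuna")
        == "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models/bert-base-uncased-pytorch_model.bin"
    )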


def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
    if is_torch_available():
        ua += f"; torch/{_torch_version}"
    if is_tf_available():
        # Unreachable here: is_tf_available() always returns False in this module,
        # so the undefined _tf_version is never evaluated.
        ua += f"; tensorflow/{_tf_version}"
    if isinstance(user_agent, dict):
        ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        ua += "; " + user_agent
    return ua
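

# A hedged sketch (not part of the original file; the helper name is ours): the
# UA string starts with pinned library/python versions and appends caller data.
def _example_http_user_agent():
    ua = http_user_agent({"file_type": "model"})
    assert ua.startswith("transformers/4.0.0; python/")
    assert ua.endswith("; file_type/model")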


def get_from_cache(
    url: str,
    cache_dir=None,
    force_download=False,
    proxies=None,
    etag_timeout=10,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    headers = {"user-agent": http_user_agent(user_agent)}
    if isinstance(use_auth_token, str):
        headers["authorization"] = "Bearer {}".format(use_auth_token)
    elif use_auth_token:
        token = HfFolder.get_token()
        if token is None:
            raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
        headers["authorization"] = "Bearer {}".format(token)

    url_to_download = url
    etag = None
    if not local_files_only:
        try:
            r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
            r.raise_for_status()
            # We favor a custom header indicating the etag of the linked resource,
            # and fall back to the regular etag header.
            # If we have neither, raise an error.
            etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
            if etag is None:
                raise OSError(
                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
                )
            # In case of a redirect, save an extra redirect on the request.get call,
            # and ensure we download the exact atomic version even if it changed
            # between the HEAD and the GET (unlikely, but hey).
            if 300 <= r.status_code <= 399:
                url_to_download = r.headers["Location"]
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            # etag is already None
            pass

    filename = url_to_filename(url, etag)

    # Get the cache path where the file will be stored.
    cache_path = os.path.join(cache_dir, filename)

    # etag is None == we don't have a connection or we passed local_files_only.
    # Try to get the last downloaded one.
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path
        else:
            matching_files = [
                file
                for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
                if not file.endswith(".json") and not file.endswith(".lock")
            ]
            if len(matching_files) > 0:
                return os.path.join(cache_dir, matching_files[-1])
            else:
                # If files cannot be found and local_files_only=True,
                # the models might have been found if local_files_only=False.
                # Notify the user about that.
                if local_files_only:
                    raise FileNotFoundError(
                        "Cannot find the requested files in the cached path and outgoing traffic has been"
                        " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                        " to False."
                    )
                else:
                    raise ValueError(
                        "Connection error, and we cannot find the requested files in the cached path."
                        " Please try again or make sure your Internet connection is on."
                    )

    # From now on, etag is not None.
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        # If the download just completed while the lock was activated.
        if os.path.exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager() -> "io.BufferedWriter":
                with open(incomplete_path, "ab") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
            resume_size = 0

        # Download to a temporary file, then move it into the cache dir once
        # finished. Otherwise you get corrupt cache entries if the download gets
        # interrupted.
        with temp_file_manager() as temp_file:
            http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)

        os.replace(temp_file.name, cache_path)

        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
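

# A hedged usage sketch (not part of the original file; the helper name and the
# throwaway cache directory are ours): fetch one file into a custom cache.
# Requires network access, so it only runs when called explicitly.
def _example_get_from_cache():
    url = hf_bucket_url("bert-base-uncased", CONFIG_NAME)
    cache_dir = os.path.join(tempfile.gettempdir(), "hf_demo_cache")
    path = get_from_cache(url, cache_dir=cache_dir)
    return path  # e.g. <cache_dir>/<sha256(url)>.<sha256(etag)>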


def cached_path(
    url_or_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    extract_compressed_file=False,
    force_extract=False,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if is_remote_url(url_or_filename):
        # URL, so get it from the cache (downloading if necessary)
        output_path = get_from_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            user_agent=user_agent,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
        )
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        output_path = url_or_filename
    elif urlparse(url_or_filename).scheme == "":
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))

    if extract_compressed_file:
        if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
            return output_path

        # Path where we extract compressed archives.
        # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
        output_dir, output_file = os.path.split(output_path)
        output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
        output_path_extracted = os.path.join(output_dir, output_extract_dir_name)

        if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
            return output_path_extracted

        # Prevent parallel extractions
        lock_path = output_path + ".lock"
        with FileLock(lock_path):
            shutil.rmtree(output_path_extracted, ignore_errors=True)
            os.makedirs(output_path_extracted)
            if is_zipfile(output_path):
                with ZipFile(output_path, "r") as zip_file:
                    zip_file.extractall(output_path_extracted)
            elif tarfile.is_tarfile(output_path):
                with tarfile.open(output_path) as tar_file:
                    tar_file.extractall(output_path_extracted)
            else:
                raise EnvironmentError("Archive format of {} could not be identified".format(output_path))

        return output_path_extracted

    return output_path
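

# A hedged usage sketch (not part of the original file; the helper name is ours):
# cached_path is the one-stop entry point -- it accepts a local path or a URL and
# returns a local path, downloading into the cache when needed (network required).
def _example_cached_path():
    url = hf_bucket_url("bert-base-uncased", CONFIG_NAME)
    local_path = cached_path(url)
    with open(local_path) as f:
        config = json.load(f)
    return config.get("model_type")  # "bert" for this model id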


def get_parameter_dtype(parameter: nn.Module):
    try:
        return next(parameter.parameters()).dtype
    except StopIteration:
        # For nn.DataParallel compatibility in PyTorch 1.5

        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].dtype
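

# A hedged sketch (not part of the original file; the helper name is ours): a
# plain module reports the dtype of its first parameter.
def _example_get_parameter_dtype():
    layer = nn.Linear(4, 2)  # parameters are float32 by default
    assert get_parameter_dtype(layer) == torch.float32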


def get_extended_attention_mask(attention_mask: Tensor, dtype) -> Tensor:
    # attention_mask: [batch_size, seq_length], 1 for real tokens, 0 for padding.
    assert attention_mask.dim() == 2
    # Broadcast to [batch_size, 1, 1, seq_length] so the mask applies to every
    # head and every query position in multi-head attention.
    extended_attention_mask = attention_mask[:, None, None, :]
    extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
    # 1 -> 0.0 (keep), 0 -> -10000.0 (suppress when added to attention scores).
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
    return extended_attention_mask
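

# A hedged sketch (not part of the original file; the helper name is ours):
# padded positions (0s) become -10000.0 and real tokens become 0.0, so adding
# the extended mask to raw attention scores effectively zeroes padding after
# the softmax.
def _example_extended_attention_mask():
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # batch of 2, seq_len 4
    ext = get_extended_attention_mask(mask, torch.float32)
    assert ext.shape == (2, 1, 1, 4)
    assert ext[0, 0, 0, 0].item() == 0.0
    assert ext[0, 0, 0, 3].item() == -10000.0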