GitHub Repository: yiming-wange/cs224n-2023-solution
Path: blob/main/minBERT/config.py
from typing import Union, Tuple, Dict, Any, Optional
import os
import json
from collections import OrderedDict
import torch
from utils import CONFIG_NAME, hf_bucket_url, cached_path, is_remote_url


class PretrainedConfig(object):
    model_type: str = ""
    is_composition: bool = False

    def __init__(self, **kwargs):
        # Attributes with defaults
        self.return_dict = kwargs.pop("return_dict", True)
        self.output_hidden_states = kwargs.pop("output_hidden_states", False)
        self.output_attentions = kwargs.pop("output_attentions", False)
        self.torchscript = kwargs.pop("torchscript", False)  # Only used by PyTorch models
        self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
        self.pruned_heads = kwargs.pop("pruned_heads", {})
        self.tie_word_embeddings = kwargs.pop(
            "tie_word_embeddings", True
        )  # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.

        # `is_decoder` is used in encoder-decoder models to differentiate the encoder from the decoder
        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
        self.is_decoder = kwargs.pop("is_decoder", False)
        self.add_cross_attention = kwargs.pop("add_cross_attention", False)
        self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)

        # Parameters for sequence generation
        self.max_length = kwargs.pop("max_length", 20)
        self.min_length = kwargs.pop("min_length", 0)
        self.do_sample = kwargs.pop("do_sample", False)
        self.early_stopping = kwargs.pop("early_stopping", False)
        self.num_beams = kwargs.pop("num_beams", 1)
        self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
        self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
        self.temperature = kwargs.pop("temperature", 1.0)
        self.top_k = kwargs.pop("top_k", 50)
        self.top_p = kwargs.pop("top_p", 1.0)
        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
        self.length_penalty = kwargs.pop("length_penalty", 1.0)
        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
        self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
        self.bad_words_ids = kwargs.pop("bad_words_ids", None)
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
        self.output_scores = kwargs.pop("output_scores", False)
        self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
        self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
        self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)

        # Fine-tuning task arguments
        self.architectures = kwargs.pop("architectures", None)
        self.finetuning_task = kwargs.pop("finetuning_task", None)
        self.id2label = kwargs.pop("id2label", None)
        self.label2id = kwargs.pop("label2id", None)
        if self.id2label is not None:
            kwargs.pop("num_labels", None)
            self.id2label = dict((int(key), value) for key, value in self.id2label.items())
            # Keys are always strings in JSON so convert ids to int here.
        else:
            self.num_labels = kwargs.pop("num_labels", 2)

        # Tokenizer arguments
        self.tokenizer_class = kwargs.pop("tokenizer_class", None)
        self.prefix = kwargs.pop("prefix", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)
        self.sep_token_id = kwargs.pop("sep_token_id", None)

        self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

        # Task-specific arguments
        self.task_specific_params = kwargs.pop("task_specific_params", None)

        # TPU arguments
        self.xla_device = kwargs.pop("xla_device", None)

        # Name or path to the pretrained checkpoint
        self._name_or_path = str(kwargs.pop("name_or_path", ""))

        # Drop the transformers version info
        kwargs.pop("transformers_version", None)

        # Additional attributes without default values
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                raise err

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        return cls.from_dict(config_dict, **kwargs)

    @classmethod
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        config = cls(**config_dict)

        if hasattr(config, "pruned_heads"):
            config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())

        # Update config with kwargs if needed
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        if return_unused_kwargs:
            return config, kwargs
        else:
            return config

    @classmethod
    def get_config_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        use_auth_token = kwargs.pop("use_auth_token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        if os.path.isdir(pretrained_model_name_or_path):
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        else:
            config_file = hf_bucket_url(
                pretrained_model_name_or_path, filename=CONFIG_NAME, revision=revision, mirror=None
            )

        try:
            # Load from URL or cache if already cached
            resolved_config_file = cached_path(
                config_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
            )
            # Load config dict
            config_dict = cls._dict_from_json_file(resolved_config_file)

        except EnvironmentError as err:
            msg = (
                f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
            )
            raise EnvironmentError(msg)

        except json.JSONDecodeError:
            msg = (
                "Couldn't reach server at '{}' to download configuration file or "
                "configuration file is not a valid JSON file. "
                "Please check network or file content here: {}.".format(config_file, resolved_config_file)
            )
            raise EnvironmentError(msg)

        return config_dict, kwargs


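# Usage sketch (illustrative, not part of the class above): from_dict builds a
# config from a plain dict; extra keyword arguments override attributes that
# already exist on the config, and return_unused_kwargs=True also returns
# whatever could not be applied, e.g.:
#
#     cfg, unused = PretrainedConfig.from_dict(
#         {"hidden_size": 256}, num_labels=5, return_unused_kwargs=True
#     )
#     # cfg.hidden_size == 256, cfg.num_labels == 5, unused == {}
#
# from_pretrained resolves pretrained_model_name_or_path (a local directory, a
# file path/URL, or a model id on huggingface.co) via get_config_dict and then
# delegates to from_dict.

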
class BertConfig(PretrainedConfig):
    model_type = "bert"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        gradient_checkpointing=False,
        position_embedding_type="absolute",
        use_cache=True,
        **kwargs
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.gradient_checkpointing = gradient_checkpointing
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
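
The BertConfig defaults above reproduce the bert-base-uncased hyperparameters (12 layers, hidden size 768, 12 attention heads, 30522-token WordPiece vocabulary). A minimal sketch of how the class might be used, assuming this file is importable as `config` from the minBERT directory; names such as `my_flag` are purely illustrative:

    from config import BertConfig

    # No-argument construction gives the bert-base-sized configuration.
    base = BertConfig()
    assert base.hidden_size == 768 and base.num_hidden_layers == 12

    # A smaller variant for quick experiments; unrecognized keyword arguments
    # (here `my_flag`) are stored as plain attributes by PretrainedConfig.__init__.
    tiny = BertConfig(
        hidden_size=128,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=512,
        my_flag=True,
    )

    # Round-trip through from_dict, e.g. after reading a local config.json.
    cfg = BertConfig.from_dict({"vocab_size": 30522, "num_labels": 3})
    print(cfg.num_labels)  # 3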