Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/text/learner.py
781 views
1
'Model training for NLP'
2
from ..torch_core import *
3
from ..basic_train import *
4
from ..callbacks import *
5
from ..data_block import CategoryList
6
from ..basic_data import *
7
from ..datasets import *
8
from ..metrics import accuracy
9
from ..train import GradientClipping
10
from ..layers import *
11
from .models import *
12
from .transform import *
13
from .data import *
14
15
# Names exported by `from .learner import *`.
__all__ = [
    'RNNLearner', 'LanguageLearner', 'convert_weights', 'decode_spec_tokens',
    'get_language_model', 'language_model_learner', 'MultiBatchEncoder',
    'get_text_classifier', 'text_classifier_learner', 'PoolingLinearClassifier',
]

# Per-architecture settings looked up by the model/learner factories below:
#   'hid_name'                 -> config key holding the size fed to the decoder / classifier head
#   'url' / 'url_bwd'          -> pretrained weights (forward / backward); a missing key means "none available"
#   'config_lm', 'config_clas' -> default configs for the language model / classifier
#   'split_lm', 'split_clas'   -> split functions handed to the learner
_model_meta = {
    AWD_LSTM: {
        'hid_name': 'emb_sz',
        'url': URLs.WT103_FWD,
        'url_bwd': URLs.WT103_BWD,
        'config_lm': awd_lstm_lm_config,
        'split_lm': awd_lstm_lm_split,
        'config_clas': awd_lstm_clas_config,
        'split_clas': awd_lstm_clas_split,
    },
    Transformer: {
        'hid_name': 'd_model',
        'url': URLs.OPENAI_TRANSFORMER,
        'config_lm': tfmer_lm_config,
        'split_lm': tfmer_lm_split,
        'config_clas': tfmer_clas_config,
        'split_clas': tfmer_clas_split,
    },
    TransformerXL: {
        'hid_name': 'd_model',
        'config_lm': tfmerXL_lm_config,
        'split_lm': tfmerXL_lm_split,
        'config_clas': tfmerXL_clas_config,
        'split_clas': tfmerXL_clas_split,
    },
}
27
28
def convert_weights(wgts:Weights, stoi_wgts:Dict[str,int], itos_new:Collection[str]) -> Weights:
    """Convert the model `wgts` to go with a new vocabulary.

    Row `i` of the new embedding matrix (and entry `i` of the decoder bias, when
    `wgts` has one) is taken from the old row of token `itos_new[i]`; tokens
    missing from `stoi_wgts` receive the mean embedding / mean bias instead.
    `wgts` is updated in place and also returned.
    """
    dec_bias, enc_wgts = wgts.get('1.decoder.bias', None), wgts['0.encoder.weight']
    has_bias = dec_bias is not None
    wgts_m = enc_wgts.mean(0)
    # `new_zeros` already yields zero-filled tensors with the source dtype and device.
    new_w = enc_wgts.new_zeros((len(itos_new), enc_wgts.size(1)))
    if has_bias:
        bias_m = dec_bias.mean(0)
        new_b = dec_bias.new_zeros((len(itos_new),))
    for i,w in enumerate(itos_new):
        r = stoi_wgts.get(w, -1)  # -1 marks a token unknown to the old vocabulary
        new_w[i] = enc_wgts[r] if r>=0 else wgts_m
        if has_bias: new_b[i] = dec_bias[r] if r>=0 else bias_m
    wgts['0.encoder.weight'] = new_w
    # Every other copy of the embedding gets its own clone of the converted matrix.
    if '0.encoder_dp.emb.weight' in wgts: wgts['0.encoder_dp.emb.weight'] = new_w.clone()
    wgts['1.decoder.weight'] = new_w.clone()
    if has_bias: wgts['1.decoder.bias'] = new_b
    return wgts
44
45
class RNNLearner(Learner):
    "Basic class for a `Learner` in NLP."
    def __init__(self, data:DataBunch, model:nn.Module, split_func:OptSplitFunc=None, clip:float=None,
                 alpha:float=2., beta:float=1., metrics=None, **learn_kwargs):
        # `accuracy` is the default metric only when the training targets are a
        # `CategoryList` or an `LMLabelList`; otherwise no metric is added by default.
        is_class = (hasattr(data.train_ds, 'y') and (isinstance(data.train_ds.y, CategoryList) or
                                                     isinstance(data.train_ds.y, LMLabelList)))
        metrics = ifnone(metrics, ([accuracy] if is_class else []))
        super().__init__(data, model, metrics=metrics, **learn_kwargs)
        # `alpha` and `beta` are forwarded unchanged to the `RNNTrainer` callback.
        self.callbacks.append(RNNTrainer(self, alpha=alpha, beta=beta))
        if clip: self.callback_fns.append(partial(GradientClipping, clip=clip))
        if split_func: self.split(split_func)

    def save_encoder(self, name:str):
        "Save the encoder to `name` inside the model directory."
        if is_pathlike(name): self._test_writeable_path()
        encoder = get_model(self.model)[0]
        # Unwrap containers that expose the real encoder as `.module`.
        if hasattr(encoder, 'module'): encoder = encoder.module
        torch.save(encoder.state_dict(), self.path/self.model_dir/f'{name}.pth')

    def load_encoder(self, name:str, device:torch.device=None):
        "Load the encoder `name` from the model directory, then freeze the learner."
        encoder = get_model(self.model)[0]
        if device is None: device = self.data.device
        # Unwrap containers that expose the real encoder as `.module`.
        if hasattr(encoder, 'module'): encoder = encoder.module
        encoder.load_state_dict(torch.load(self.path/self.model_dir/f'{name}.pth', map_location=device))
        self.freeze()

    def load_pretrained(self, wgts_fname:str, itos_fname:str, strict:bool=True):
        "Load a pretrained model and adapts it to the data vocabulary."
        # NOTE: unpickling can execute arbitrary code; only load vocabulary files from trusted sources.
        # The context manager closes the file handle that the bare `open(...)` call used to leak.
        with open(itos_fname, 'rb') as f: old_itos = pickle.load(f)
        old_stoi = {v:k for k,v in enumerate(old_itos)}
        # Keep every tensor on CPU while loading, whatever device it was saved from.
        wgts = torch.load(wgts_fname, map_location=lambda storage, loc: storage)
        # Some checkpoints nest the weights under a 'model' key.
        if 'model' in wgts: wgts = wgts['model']
        wgts = convert_weights(wgts, old_stoi, self.data.train_ds.vocab.itos)
        self.model.load_state_dict(wgts, strict=strict)

    def get_preds(self, ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None, with_loss:bool=False, n_batch:Optional[int]=None,
                  pbar:Optional[PBar]=None, ordered:bool=False) -> List[Tensor]:
        "Return predictions and targets on the valid, train, or test set, depending on `ds_type`."
        self.model.reset()
        if ordered: np.random.seed(42)
        preds = super().get_preds(ds_type=ds_type, activ=activ, with_loss=with_loss, n_batch=n_batch, pbar=pbar)
        if ordered and hasattr(self.dl(ds_type), 'sampler'):
            # Same seed as above, so iterating the sampler again presumably replays the
            # order used while predicting; its argsort then restores dataset order.
            np.random.seed(42)
            sampler = list(self.dl(ds_type).sampler)
            reverse_sampler = np.argsort(sampler)
            preds = [p[reverse_sampler] for p in preds]
        return preds
93
94
def decode_spec_tokens(tokens):
    """Apply the special-token markers found in `tokens` and return the decoded list.

    `TK_MAJ` capitalizes the next token and `TK_UP` upper-cases it. `TK_REP` and
    `TK_WREP` are followed by a count and a token: `TK_REP` repeats the token inside
    a single string, `TK_WREP` emits it as `count` separate tokens. A count that is
    not an integer cancels the rule and that token is dropped.
    """
    new_toks,rule,arg = [],None,None
    for t in tokens:
        if t in [TK_MAJ, TK_UP, TK_REP, TK_WREP]:
            # A new marker starts from scratch: forget any count left by an unfinished rule.
            rule,arg = t,None
        elif rule is None: new_toks.append(t)
        elif rule == TK_MAJ:
            new_toks.append(t[:1].upper() + t[1:].lower())
            rule = None
        elif rule == TK_UP:
            new_toks.append(t.upper())
            rule = None
        elif arg is None:
            # First token after TK_REP/TK_WREP is the repetition count.
            try: arg = int(t)
            except ValueError: rule = None
        else:
            if rule == TK_REP: new_toks.append(t * arg)
            else: new_toks += [t] * arg
            # The rule is consumed; without this reset every later token would be repeated too.
            rule,arg = None,None
    return new_toks
112
113
class LanguageLearner(RNNLearner):
    "Subclass of RNNLearner for predictions."

    def predict(self, text:str, n_words:int=1, no_unk:bool=True, temperature:float=1., min_p:float=None, sep:str=' ',
                decoder=decode_spec_tokens):
        "Return `text` followed by `n_words` sampled words, joined with `sep`."
        self.model.reset()
        xb,yb = self.data.one_item(text)
        new_idx = []
        for _ in range(n_words):
            res = self.pred_batch(batch=(xb,yb))[0][-1]
            # Zeroing an entry makes it impossible for `multinomial` to draw it.
            if no_unk: res[self.data.vocab.stoi[UNK]] = 0.
            if min_p is not None:
                if (res >= min_p).float().sum() == 0:
                    warn(f"There is no item with probability >= {min_p}, try a lower value.")
                else: res[res < min_p] = 0.
            if temperature != 1.: res.pow_(1 / temperature)
            idx = torch.multinomial(res, 1).item()
            new_idx.append(idx)
            # Only the newly drawn token is fed back; the model is not reset between steps.
            xb = xb.new_tensor([idx])[None]
        return text + sep + sep.join(decoder(self.data.vocab.textify(new_idx, sep=None)))

    def beam_search(self, text:str, n_words:int, no_unk:bool=True, top_k:int=10, beam_sz:int=1000, temperature:float=1.,
                    sep:str=' ', decoder=decode_spec_tokens):
        "Return the `n_words` that come after `text` using beam search."
        self.model.reset()
        self.model.eval()
        xb,_ = self.data.one_item(text)
        nodes = xb.clone()
        # `scores` holds negative log-likelihoods, so lower is better.
        scores = xb.new_zeros(1).float()
        with torch.no_grad():
            for k in progress_bar(range(n_words), leave=False):
                out = F.log_softmax(self.model(xb)[0][:,-1], dim=-1)
                if no_unk: out[:,self.data.vocab.stoi[UNK]] = -float('Inf')
                values, indices = out.topk(top_k, dim=-1)
                scores = (-values + scores[:,None]).view(-1)
                # Parent beam of every candidate, created on the same device as the
                # tensors it is indexed with below.
                indices_idx = torch.arange(0, nodes.size(0), device=nodes.device)[:,None]
                indices_idx = indices_idx.expand(nodes.size(0), top_k).contiguous().view(-1)
                # Keep the `beam_sz` candidates with the lowest score.
                sort_idx = scores.argsort()[:beam_sz]
                scores = scores[sort_idx]
                nodes = torch.cat([nodes[:,None].expand(nodes.size(0),top_k,nodes.size(1)),
                                   indices[:,:,None].expand(nodes.size(0),top_k,1),], dim=2)
                nodes = nodes.view(-1, nodes.size(2))[sort_idx]
                self.model[0].select_hidden(indices_idx[sort_idx])
                xb = nodes[:,-1][:,None]
        if temperature != 1.: scores.div_(temperature)
        node_idx = torch.multinomial(torch.exp(-scores), 1).item()
        # NOTE(review): each row of `nodes` starts with the prompt's own token ids, so slicing
        # from 1 appears to repeat the prompt after `text` in the result -- confirm this is intended.
        return text + sep + sep.join(decoder(self.data.vocab.textify([i.item() for i in nodes[node_idx][1:] ], sep=None)))

    def show_results(self, ds_type=DatasetType.Valid, rows:int=5, max_len:int=20):
        "Show `rows` result of predictions on `ds_type` dataset."
        from IPython.display import display, HTML
        ds = self.dl(ds_type).dataset
        x,y = self.data.one_batch(ds_type, detach=False, denorm=False)
        preds = self.pred_batch(batch=(x,y))
        y = y.view(*x.size())
        z = preds.view(*x.size(),-1).argmax(dim=2)
        xs = [ds.x.reconstruct(grab_idx(x, i)) for i in range(rows)]
        ys = [ds.x.reconstruct(grab_idx(y, i)) for i in range(rows)]
        zs = [ds.x.reconstruct(grab_idx(z, i)) for i in range(rows)]
        items,names = [],['text', 'target', 'pred']
        for inp,targ,pred in zip(xs,ys,zs):
            txt_x = ' '.join(inp.text.split(' ')[:max_len])
            # Target and prediction are shown for the window starting at the last displayed input word.
            txt_y = ' '.join(targ.text.split(' ')[max_len-1:2*max_len-1])
            txt_z = ' '.join(pred.text.split(' ')[max_len-1:2*max_len-1])
            items.append([txt_x, txt_y, txt_z])
        # Building the frame from the row list also works when `rows` is 0.
        df = pd.DataFrame(items, columns=names)
        # NOTE(review): newer pandas releases expect `None` rather than -1 for "no limit";
        # confirm the supported pandas range before changing this value.
        with pd.option_context('display.max_colwidth', -1):
            display(HTML(df.to_html(index=False)))
186
187
def get_language_model(arch:Callable, vocab_sz:int, config:dict=None, drop_mult:float=1.):
    "Build a `SequentialRNN` made of an `arch` encoder and a `LinearDecoder`, configured by `config`."
    meta = _model_meta[arch]
    # Work on a copy so neither the caller's dict nor the architecture default is modified.
    config = ifnone(config, meta['config_lm']).copy()
    # Every dropout probability (keys ending in `_p`) is scaled by `drop_mult`.
    for key in config:
        if key.endswith('_p'): config[key] *= drop_mult
    # Decoder-only settings are taken out; what remains is passed to `arch` as keyword arguments.
    tie_weights = config.pop('tie_weights')
    output_p = config.pop('output_p')
    out_bias = config.pop('out_bias')
    init = config.pop('init', None)
    encoder = arch(vocab_sz, **config)
    tied = encoder.encoder if tie_weights else None
    decoder = LinearDecoder(vocab_sz, config[meta['hid_name']], output_p, tie_encoder=tied, bias=out_bias)
    model = SequentialRNN(encoder, decoder)
    if init is None: return model
    return model.apply(init)
200
201
def language_model_learner(data:DataBunch, arch, config:dict=None, drop_mult:float=1., pretrained:bool=True,
                           pretrained_fnames:OptStrTuple=None, **learn_kwargs) -> 'LanguageLearner':
    "Build a `LanguageLearner` for `data` around an `arch` language model, optionally loading pretrained weights."
    model = get_language_model(arch, len(data.vocab.itos), config=config, drop_mult=drop_mult)
    meta = _model_meta[arch]
    learn = LanguageLearner(data, model, split_func=meta['split_lm'], **learn_kwargs)
    # Backward data uses the backward pretrained weights.
    url_key = 'url_bwd' if data.backwards else 'url'
    if not (pretrained or pretrained_fnames): return learn
    if pretrained_fnames is None:
        if url_key not in meta:
            warn("There are no pretrained weights for that architecture yet!")
            return learn
        model_path = untar_data(meta[url_key] , data=False)
        # First weights file and first vocabulary file found in the downloaded folder.
        fnames = [list(model_path.glob(f'*.{ext}'))[0] for ext in ['pth', 'pkl']]
    else:
        # User-supplied (weights, vocab) base names live in the learner's model directory.
        model_dir = learn.path/learn.model_dir
        fnames = [model_dir/f'{fn}.{ext}' for fn,ext in zip(pretrained_fnames, ['pth', 'pkl'])]
    learn.load_pretrained(*fnames)
    learn.freeze()
    return learn
220
221
def masked_concat_pool(outputs, mask):
    "Concatenate [last step, masked max, masked mean] of the last element of `outputs` (pooled over dim 1)."
    last_layer = outputs[-1]
    seq_len = last_layer.size(1)
    pad = mask[:, :, None]
    # Mean over all positions with masked ones zeroed, then rescaled to count only unmasked positions.
    avg_pool = last_layer.masked_fill(pad, 0).mean(dim=1)
    n_pad = mask.type(avg_pool.dtype).sum(dim=1)
    avg_pool *= seq_len / (seq_len - n_pad)[:,None]
    # Masked positions are set to -inf so they can never be the maximum.
    max_pool, _ = last_layer.masked_fill(pad, -float('inf')).max(dim=1)
    return torch.cat([last_layer[:,-1], max_pool, avg_pool], 1)
229
230
class PoolingLinearClassifier(Module):
    "Classifier head: pool the encoder outputs, then run them through a stack of `bn_drop_lin` blocks."
    def __init__(self, layers:Collection[int], drops:Collection[float]):
        n_lin = len(layers) - 1
        if len(drops) != n_lin: raise ValueError("Number of layers and dropout values do not match.")
        # One ReLU instance is reused after every block except the last, which has no activation.
        relu = nn.ReLU(inplace=True)
        activs = [relu] * (n_lin - 1) + [None]
        blocks = []
        for n_in, n_out, p, actn in zip(layers[:-1], layers[1:], drops, activs):
            blocks.extend(bn_drop_lin(n_in, n_out, p=p, actn=actn))
        self.layers = nn.Sequential(*blocks)

    def forward(self, input:Tuple[Tensor,Tensor, Tensor])->Tuple[Tensor,Tensor,Tensor]:
        raw_outputs,outputs,mask = input
        pooled = masked_concat_pool(outputs, mask)
        # The encoder outputs are passed through next to the classifier result.
        return self.layers(pooled), raw_outputs, outputs
245
246
class MultiBatchEncoder(Module):
    "Create an encoder over `module` that can process a full sentence."
    def __init__(self, bptt:int, max_len:int, module:nn.Module, pad_idx:int=1):
        self.max_len,self.bptt,self.module,self.pad_idx = max_len,bptt,module,pad_idx

    def concat(self, arrs:Collection[Collection[Tensor]])->List[Tensor]:
        "Concatenate the per-chunk tensors in `arrs` along dim 1, returning one tensor per position in a chunk's list."
        return [torch.cat([l[si] for l in arrs], dim=1) for si in range_of(arrs[0])]

    def reset(self):
        "Reset the wrapped `module` when it provides a `reset` method."
        if hasattr(self.module, 'reset'): self.module.reset()

    def forward(self, input:LongTensor)->Tuple[List[Tensor],List[Tensor],Tensor]:
        "Run `module` over `input` in chunks of `bptt` columns and return (raw_outputs, outputs, padding mask)."
        bs,sl = input.size()
        self.reset()
        raw_outputs,outputs,masks = [],[],[]
        for i in range(0, sl, self.bptt):
            chunk = input[:,i: min(i+self.bptt, sl)]
            # Every chunk goes through `module`, but results are kept only for chunks
            # starting within the last `max_len` positions.
            r, o = self.module(chunk)
            if i>(sl-self.max_len):
                masks.append(chunk == self.pad_idx)
                raw_outputs.append(r)
                outputs.append(o)
        return self.concat(raw_outputs),self.concat(outputs),torch.cat(masks,dim=1)
269
270
def get_text_classifier(arch:Callable, vocab_sz:int, n_class:int, bptt:int=70, max_len:int=20*70, config:dict=None,
                        drop_mult:float=1., lin_ftrs:Collection[int]=None, ps:Collection[float]=None,
                        pad_idx:int=1) -> nn.Module:
    """Create a text classifier from `arch` and its `config`.

    The model is a `MultiBatchEncoder` around `arch` followed by a
    `PoolingLinearClassifier`. `lin_ftrs` gives the hidden sizes of the head
    (default `[50]`) and `ps` the dropout of each of those layers (default 0.1
    each); the config's `output_p` is used as dropout for the first head layer.
    """
    meta = _model_meta[arch]
    # Work on a copy so neither the caller's dict nor the architecture default is modified.
    config = ifnone(config, meta['config_clas']).copy()
    for k in config:
        if k.endswith('_p'): config[k] *= drop_mult
    # Copy into lists so any collection (e.g. a tuple) can be concatenated below.
    lin_ftrs = [50] if lin_ftrs is None else list(lin_ftrs)
    ps = [0.1]*len(lin_ftrs) if ps is None else list(ps)
    # The head input is three times the encoder size: `masked_concat_pool` joins three pooled vectors.
    layers = [config[meta['hid_name']] * 3] + lin_ftrs + [n_class]
    ps = [config.pop('output_p')] + ps
    init = config.pop('init', None)
    encoder = MultiBatchEncoder(bptt, max_len, arch(vocab_sz, **config), pad_idx=pad_idx)
    model = SequentialRNN(encoder, PoolingLinearClassifier(layers, ps))
    return model if init is None else model.apply(init)
286
287
def text_classifier_learner(data:DataBunch, arch:Callable, bptt:int=70, max_len:int=70*20, config:dict=None,
                            pretrained:bool=True, drop_mult:float=1., lin_ftrs:Collection[int]=None,
                            ps:Collection[float]=None, **learn_kwargs) -> 'RNNLearner':
    """Create a `Learner` with a text classifier from `data` and `arch`.

    With `pretrained`, the weights registered under `'url'` for `arch` are downloaded,
    loaded non-strictly (the checkpoint need not match the classifier exactly) and the
    learner is frozen. If `arch` has no registered weights, a warning is issued and the
    learner is returned as built.
    """
    model = get_text_classifier(arch, len(data.vocab.itos), data.c, bptt=bptt, max_len=max_len,
                                config=config, drop_mult=drop_mult, lin_ftrs=lin_ftrs, ps=ps)
    meta = _model_meta[arch]
    learn = RNNLearner(data, model, split_func=meta['split_clas'], **learn_kwargs)
    if pretrained:
        if 'url' not in meta:
            warn("There are no pretrained weights for that architecture yet!")
            return learn
        model_path = untar_data(meta['url'], data=False)
        # First weights file and first vocabulary file found in the downloaded folder.
        fnames = [list(model_path.glob(f'*.{ext}'))[0] for ext in ['pth', 'pkl']]
        learn.load_pretrained(*fnames, strict=False)
        learn.freeze()
    return learn
304
305