Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/text/models/awd_lstm.py
841 views
1
from ...torch_core import *
2
from ...layers import *
3
from ...train import ClassificationInterpretation
4
from ...basic_train import *
5
from ...basic_data import *
6
from ..data import TextClasDataBunch
7
import matplotlib.cm as cm
8
9
# Public names exported by `from <this module> import *`.
__all__ = [
    'EmbeddingDropout', 'LinearDecoder', 'AWD_LSTM', 'RNNDropout',
    'SequentialRNN', 'WeightDropout', 'dropout_mask', 'awd_lstm_lm_split', 'awd_lstm_clas_split',
    'awd_lstm_lm_config', 'awd_lstm_clas_config', 'TextClassificationInterpretation',
]
12
13
def dropout_mask(x:Tensor, sz:Collection[int], p:float):
    """Return a dropout mask of the same type as `x`, size `sz`, with probability `p` to cancel an element.

    Kept elements equal `1/(1-p)`, so `input * mask` has the same expected value as `input`.
    For `p == 1` every element is cancelled and an all-zero mask is returned; the generic formula
    would divide 0 by 0 there and fill the mask with NaN.
    """
    if p == 1: return x.new(*sz).zero_()
    return x.new(*sz).bernoulli_(1-p).div_(1-p)
16
17
class RNNDropout(Module):
    "Dropout with probability `p` that is consistent on the seq_len dimension."

    def __init__(self, p:float=0.5): self.p=p

    def forward(self, x:Tensor)->Tensor:
        # Only drop while training with a non-zero probability; otherwise the input passes through as-is.
        if self.training and self.p != 0.:
            # The mask has size 1 on dim 1, so one draw per (dim 0, dim 2) entry is broadcast along dim 1.
            keep = dropout_mask(x.data, (x.size(0), 1, x.size(2)), self.p)
            return x * keep
        return x
26
27
class WeightDropout(Module):
    "A module that wraps another layer in which some weights will be replaced by 0 during training."

    def __init__(self, module:nn.Module, weight_p:float, layer_names:Collection[str]=('weight_hh_l0',)):
        self.module,self.weight_p = module,weight_p
        # Accept a single name as well as a collection, and copy it so instances never share one list.
        self.layer_names = [layer_names] if isinstance(layer_names, str) else list(layer_names)
        for layer in self.layer_names:
            #Makes a copy of the weights of the selected layers.
            w = getattr(self.module, layer)
            self.register_parameter(f'{layer}_raw', nn.Parameter(w.data))
            self._set_layer_weight(layer, F.dropout(w, p=self.weight_p, training=False))

    def _set_layer_weight(self, layer:str, w:Tensor):
        "Install `w` as the weight called `layer` of the wrapped module."
        self.module._parameters[layer] = w
        # nn.RNNBase in newer PyTorch caches its weights in `_flat_weights` and reads that list in `forward`.
        # Writing `_parameters` directly bypasses its `__setattr__` hook, so update the cache as well;
        # otherwise the RNN keeps using the stale weight and `<layer>_raw` never receives gradients.
        names = getattr(self.module, '_flat_weights_names', None)
        if names is not None and layer in names:
            self.module._flat_weights[names.index(layer)] = w

    def _setweights(self):
        "Apply dropout to the raw weights."
        for layer in self.layer_names:
            raw_w = getattr(self, f'{layer}_raw')
            self._set_layer_weight(layer, F.dropout(raw_w, p=self.weight_p, training=self.training))

    def forward(self, *args:ArgStar):
        "Resample the dropped weights, then run the wrapped module on `args`."
        self._setweights()
        with warnings.catch_warnings():
            #To avoid the warning that comes because the weights aren't flattened.
            warnings.simplefilter("ignore")
            return self.module.forward(*args)

    def reset(self):
        "Restore the un-dropped raw weights, then reset the wrapped module if it supports it."
        for layer in self.layer_names:
            raw_w = getattr(self, f'{layer}_raw')
            self._set_layer_weight(layer, F.dropout(raw_w, p=self.weight_p, training=False))
        if hasattr(self.module, 'reset'): self.module.reset()
56
57
class EmbeddingDropout(Module):
    "Apply dropout with probability `embed_p` to an embedding layer `emb`."

    def __init__(self, emb:nn.Module, embed_p:float):
        self.emb,self.embed_p = emb,embed_p
        # Forwarded to `F.embedding`. `None` means "no padding row". An explicit -1 would instead be
        # read as "the last row is padding", silently zeroing that token's gradient.
        self.pad_idx = self.emb.padding_idx

    def forward(self, words:LongTensor, scale:Optional[float]=None)->Tensor:
        "Look up `words` in the (possibly row-dropped, optionally `scale`d) embedding matrix."
        if self.training and self.embed_p != 0:
            # One draw per vocabulary row: a dropped word is zeroed everywhere it occurs in the batch.
            size = (self.emb.weight.size(0),1)
            mask = dropout_mask(self.emb.weight.data, size, self.embed_p)
            masked_embed = self.emb.weight * mask
        else: masked_embed = self.emb.weight
        # Out of place: `masked_embed` may be the embedding's own weight Parameter, which must not be modified.
        if scale: masked_embed = masked_embed * scale
        return F.embedding(words, masked_embed, self.pad_idx, self.emb.max_norm,
                           self.emb.norm_type, self.emb.scale_grad_by_freq, self.emb.sparse)
74
75
class AWD_LSTM(Module):
    "AWD-LSTM/QRNN inspired by https://arxiv.org/abs/1708.02182."

    initrange=0.1

    def __init__(self, vocab_sz:int, emb_sz:int, n_hid:int, n_layers:int, pad_token:int=1, hidden_p:float=0.2,
                 input_p:float=0.6, embed_p:float=0.1, weight_p:float=0.5, qrnn:bool=False, bidir:bool=False):
        self.bs,self.qrnn,self.emb_sz,self.n_hid,self.n_layers = 1,qrnn,emb_sz,n_hid,n_layers
        self.n_dir = 2 if bidir else 1
        self.encoder = nn.Embedding(vocab_sz, emb_sz, padding_idx=pad_token)
        self.encoder_dp = EmbeddingDropout(self.encoder, embed_p)
        # Layer l maps emb_sz (l == 0) or n_hid inputs to n_hid outputs, except the last layer, which maps
        # back to emb_sz. Each direction gets half the width when bidirectional.
        if self.qrnn:
            #Using QRNN requires an installation of cuda
            from .qrnn import QRNN
            self.rnns = [QRNN(emb_sz if l == 0 else n_hid, (n_hid if l != n_layers - 1 else emb_sz)//self.n_dir, 1,
                              save_prev_x=True, zoneout=0, window=2 if l == 0 else 1, output_gate=True, bidirectional=bidir)
                         for l in range(n_layers)]
            for rnn in self.rnns:
                rnn.layers[0].linear = WeightDropout(rnn.layers[0].linear, weight_p, layer_names=['weight'])
        else:
            self.rnns = [nn.LSTM(emb_sz if l == 0 else n_hid, (n_hid if l != n_layers - 1 else emb_sz)//self.n_dir, 1,
                                 batch_first=True, bidirectional=bidir) for l in range(n_layers)]
            self.rnns = [WeightDropout(rnn, weight_p) for rnn in self.rnns]
        self.rnns = nn.ModuleList(self.rnns)
        self.encoder.weight.data.uniform_(-self.initrange, self.initrange)
        self.input_dp = RNNDropout(input_p)
        self.hidden_dps = nn.ModuleList([RNNDropout(hidden_p) for l in range(n_layers)])

    def forward(self, input:Tensor, from_embeddings:bool=False)->Tuple[List[Tensor],List[Tensor]]:
        """Run the stacked RNNs on `input`: token ids, or already-computed embeddings if `from_embeddings`.

        Returns `(raw_outputs, outputs)`, one entry per layer: each layer's output before and after the
        inter-layer `hidden_dps` dropout (the last layer gets none). The new hidden state is stored on
        `self.hidden` (through `to_detach`) and fed to the next call.
        """
        # Unpacking (rather than `input.size(0)`) keeps the rank check: 3-d embeddings, or 2-d token ids.
        if from_embeddings: bs,_,_ = input.size()
        else: bs,_ = input.size()
        # `self.hidden` only exists once `reset` has run, so also build it on first use. Otherwise a first
        # batch whose size matches the initial `self.bs` would fail with AttributeError.
        if bs!=self.bs or getattr(self, 'hidden', None) is None:
            self.bs=bs
            self.reset()
        raw_output = self.input_dp(input if from_embeddings else self.encoder_dp(input))
        new_hidden,raw_outputs,outputs = [],[],[]
        for l, (rnn,hid_dp) in enumerate(zip(self.rnns, self.hidden_dps)):
            raw_output, new_h = rnn(raw_output, self.hidden[l])
            new_hidden.append(new_h)
            raw_outputs.append(raw_output)
            if l != self.n_layers - 1: raw_output = hid_dp(raw_output)
            outputs.append(raw_output)
        self.hidden = to_detach(new_hidden, cpu=False)
        return raw_outputs, outputs

    def _one_hidden(self, l:int)->Tensor:
        "Return one hidden state."
        # Zero tensor of shape (n_dir, bs, nh), created with `.new` from `one_param(self)` (presumably a model
        # parameter, so it matches that parameter's dtype/device).
        nh = (self.n_hid if l != self.n_layers - 1 else self.emb_sz) // self.n_dir
        return one_param(self).new(self.n_dir, self.bs, nh).zero_()

    def select_hidden(self, idxs):
        "Keep only the batch entries `idxs` (dim 1) of the stored hidden state and update `self.bs`."
        if self.qrnn: self.hidden = [h[:,idxs,:] for h in self.hidden]
        else: self.hidden = [(h[0][:,idxs,:],h[1][:,idxs,:]) for h in self.hidden]
        self.bs = len(idxs)

    def reset(self):
        "Reset the hidden states."
        for r in self.rnns:
            if hasattr(r, 'reset'): r.reset()
        # A QRNN layer keeps a single state tensor; an LSTM layer keeps a (h, c) pair.
        if self.qrnn: self.hidden = [self._one_hidden(l) for l in range(self.n_layers)]
        else: self.hidden = [(self._one_hidden(l), self._one_hidden(l)) for l in range(self.n_layers)]
135
136
class LinearDecoder(Module):
    "To go on top of a RNNCore module and create a Language Model."
    initrange=0.1

    def __init__(self, n_out:int, n_hid:int, output_p:float, tie_encoder:nn.Module=None, bias:bool=True):
        self.decoder = nn.Linear(n_hid, n_out, bias=bias)
        self.decoder.weight.data.uniform_(-self.initrange, self.initrange)
        self.output_dp = RNNDropout(output_p)
        if bias: self.decoder.bias.data.zero_()
        # Weight tying: share `tie_encoder`'s weight Parameter with the decoder (replacing the init above).
        if tie_encoder is not None: self.decoder.weight = tie_encoder.weight

    def forward(self, input:Tuple[Tensor,Tensor])->Tuple[Tensor,Tensor,Tensor]:
        "Decode the last layer of `outputs` from `input = (raw_outputs, outputs)`; pass both lists through."
        raw_outputs, outputs = input
        output = self.output_dp(outputs[-1])
        decoded = self.decoder(output)
        return decoded, raw_outputs, outputs
152
153
class SequentialRNN(nn.Sequential):
    "A sequential module that passes the reset call to its children."
    def reset(self):
        "Call `reset()` on every direct child that defines one; other children are skipped."
        resettable = (child for child in self.children() if hasattr(child, 'reset'))
        for child in resettable:
            child.reset()
158
159
def awd_lstm_lm_split(model:nn.Module) -> List[nn.Module]:
    "Split a RNN `model` in groups for differential learning rates."
    # One group per (rnn, hidden dropout) layer, then a final group with the embedding and the decoder `model[1]`.
    rnn_core, decoder = model[0], model[1]
    layer_groups = [[rnn, dp] for rnn, dp in zip(rnn_core.rnns, rnn_core.hidden_dps)]
    layer_groups.append([rnn_core.encoder, rnn_core.encoder_dp, decoder])
    return layer_groups
163
164
def awd_lstm_clas_split(model:nn.Module) -> List[nn.Module]:
    "Split a RNN `model` in groups for differential learning rates."
    # Embedding group first, then one group per (rnn, hidden dropout) layer, then the head `model[1]`.
    # `model[0].module` is the AWD_LSTM wrapped by the classifier's encoder.
    core = model[0].module
    return [[core.encoder, core.encoder_dp],
            *([rnn, dp] for rnn, dp in zip(core.rnns, core.hidden_dps)),
            [model[1]]]
169
170
# Default keyword settings for an AWD-LSTM language model.
awd_lstm_lm_config = {
    'emb_sz': 400, 'n_hid': 1152, 'n_layers': 3, 'pad_token': 1, 'qrnn': False, 'bidir': False,
    'output_p': 0.1, 'hidden_p': 0.15, 'input_p': 0.25, 'embed_p': 0.02, 'weight_p': 0.2,
    'tie_weights': True, 'out_bias': True,
}

# Default keyword settings for an AWD-LSTM text classifier (heavier dropout than the language model).
awd_lstm_clas_config = {
    'emb_sz': 400, 'n_hid': 1152, 'n_layers': 3, 'pad_token': 1, 'qrnn': False, 'bidir': False,
    'output_p': 0.4, 'hidden_p': 0.3, 'input_p': 0.4, 'embed_p': 0.05, 'weight_p': 0.5,
}
175
176
def value2rgba(x:float, cmap:Callable=cm.RdYlGn, alpha_mult:float=1.0)->Tuple:
    """Map `x` (expected in [0, 1]) through `cmap` and return `(r, g, b, a)`.

    `r`, `g`, `b` are the colormap's channels scaled to 0-255 and truncated to int; `a` is the
    colormap's alpha multiplied by `alpha_mult`.
    """
    *rgb_frac, alpha = cmap(x)
    rgb = (np.array(rgb_frac) * 255).astype(int).tolist()
    return (*rgb, alpha * alpha_mult)
182
183
def piece_attn_html(pieces:List[str], attns:List[float], sep:str=' ', **kwargs)->str:
    """Render `pieces` as one monospace HTML span of `sep`-joined inner spans.

    Each inner span holds the HTML-escaped piece, shaded by `value2rgba(attn, alpha_mult=0.5, **kwargs)`
    and carrying the weight (3 decimals) as its tooltip.
    """
    def _span(piece, attn):
        escaped = html.escape(piece)
        rgba = str(value2rgba(attn, alpha_mult=0.5, **kwargs))
        return f'<span title="{attn:.3f}" style="background-color: rgba{rgba};">{escaped}</span>'
    body = sep.join(_span(piece, attn) for piece, attn in zip(pieces, attns))
    return '<span style="font-family: monospace;">' + body + '</span>'
192
193
def show_piece_attn(*args, **kwargs):
    "Display the HTML produced by `piece_attn_html(*args, **kwargs)` through IPython's rich display."
    from IPython.display import HTML, display
    markup = piece_attn_html(*args, **kwargs)
    display(HTML(markup))
196
197
def _eval_dropouts(mod):
198
module_name = mod.__class__.__name__
199
if 'Dropout' in module_name or 'BatchNorm' in module_name: mod.training = False
200
for module in mod.children(): _eval_dropouts(module)
201
202
class TextClassificationInterpretation(ClassificationInterpretation):
    """Provides an interpretation of classification based on input sensitivity.
    This was designed for AWD-LSTM only for the moment, because Transformer already has its own attentional model.
    """

    def __init__(self, learn: Learner, preds: Tensor, y_true: Tensor, losses: Tensor, ds_type: DatasetType = DatasetType.Valid):
        super().__init__(learn,preds,y_true,losses,ds_type)
        self.model = learn.model

    def intrinsic_attention(self, text:str, class_id:int=None):
        """Calculate the intrinsic attention of the input w.r.t to an output `class_id`, or the classification given by the model if `None`.
        For reference, see the Sequential Jacobian session at https://www.cs.toronto.edu/~graves/preprint.pdf

        Returns `(tokens, attn)`: the reconstructed text and one weight per token (summed absolute gradient
        of the class probability w.r.t. that token's embedding), scaled so the largest weight is 1.
        """
        # Training mode with only dropout/batchnorm layers switched off, so the RNN pass stays differentiable.
        # NOTE(review): presumably needed because cuDNN RNNs only allow backward in training mode — confirm.
        self.model.train()
        _eval_dropouts(self.model)
        self.model.zero_grad()
        self.model.reset()
        ids = self.data.one_item(text)[0]
        # The embeddings become the leaf of the graph so their gradient can be read back.
        emb = self.model[0].module.encoder(ids).detach().requires_grad_(True)
        lstm_output = self.model[0].module(emb, from_embeddings=True)
        self.model.eval()
        # The extra all-zero tensor is passed to the head as a third input (presumably a padding mask with nothing masked).
        cl = self.model[1](lstm_output + (torch.zeros_like(ids).byte(),))[0].softmax(dim=-1)
        if class_id is None: class_id = cl.argmax()
        cl[0][class_id].backward()
        # Select the single batch row explicitly. `squeeze()` would also drop the sequence axis of a
        # one-token text and collapse the attention to a scalar.
        attn = emb.grad[0].abs().sum(dim=-1)
        attn /= attn.max()
        tokens = self.data.single_ds.reconstruct(ids[0])
        return tokens, attn

    def html_intrinsic_attention(self, text:str, class_id:int=None, **kwargs)->str:
        "Return the intrinsic attention of `text` as colored HTML (see `piece_attn_html`)."
        text, attn = self.intrinsic_attention(text, class_id)
        return piece_attn_html(text.text.split(), to_np(attn), **kwargs)

    def show_intrinsic_attention(self, text:str, class_id:int=None, **kwargs)->None:
        "Display the intrinsic attention of `text` as colored HTML in the notebook."
        text, attn = self.intrinsic_attention(text, class_id)
        show_piece_attn(text.text.split(), to_np(attn), **kwargs)

    def show_top_losses(self, k:int, max_len:int=70)->None:
        """
        Create a tabulation showing the first `k` texts in top_losses along with their prediction, actual,loss, and probability of
        actual class. `max_len` is the maximum number of tokens displayed.
        """
        from IPython.display import display, HTML
        _,tl_idx = self.top_losses()
        ds = self.data.dl(self.ds_type).dataset
        classes = self.data.classes
        items = []
        # `max(k, 0)`: a non-positive `k` shows no rows (a negative slice bound would count from the end).
        for idx in tl_idx[:max(k, 0)]:
            tx,cl = ds[idx]
            cl = cl.data
            txt = ' '.join(tx.text.split(' ')[:max_len]) if max_len is not None else tx.text
            items.append([txt, f'{classes[self.pred_class[idx]]}', f'{classes[cl]}', f'{self.losses[idx]:.2f}',
                          f'{self.preds[idx][cl]:.2f}'])
        names = ['Text', 'Prediction', 'Actual', 'Loss', 'Probability']
        # Build from the rows directly so an empty result still yields a (column-only) table.
        df = pd.DataFrame(items, columns=names)
        # "No truncation" is None on pandas >= 1.0 (where -1 was deprecated and later rejected); older pandas needs -1.
        unlimited = None if int(pd.__version__.split('.')[0]) >= 1 else -1
        with pd.option_context('display.max_colwidth', unlimited):
            display(HTML(df.to_html(index=False)))
262
263