GitHub Repository: yiming-wange/cs224n-2023-solution
Path: blob/main/a4/nmt_model.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
CS224N 2022-23: Homework 4
nmt_model.py: NMT Model
Pencheng Yin <[email protected]>
Sahil Chopra <[email protected]>
Vera Lin <[email protected]>
Siyan Li <[email protected]>
"""
from collections import namedtuple
import sys
from typing import List, Tuple, Dict, Set, Union
import torch
import torch.nn as nn
import torch.nn.utils
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

from model_embeddings import ModelEmbeddings

Hypothesis = namedtuple('Hypothesis', ['value', 'score'])
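# A completed beam-search candidate: `value` is the decoded target sentence as a
# list of words, `score` is its log-likelihood (see beam_search below).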


class NMT(nn.Module):
    """ Simple Neural Machine Translation Model:
        - Bidirectional LSTM Encoder
        - Unidirectional LSTM Decoder
        - Global Attention Model (Luong, et al. 2015)
    """

    def __init__(self, embed_size, hidden_size, vocab, dropout_rate=0.2):
        """ Init NMT Model.

        @param embed_size (int): Embedding size (dimensionality)
        @param hidden_size (int): Hidden Size, the size of hidden states (dimensionality)
        @param vocab (Vocab): Vocabulary object containing src and tgt languages
                              See vocab.py for documentation.
        @param dropout_rate (float): Dropout probability, for attention
        """
        super(NMT, self).__init__()
        self.model_embeddings = ModelEmbeddings(embed_size, vocab)
        self.hidden_size = hidden_size
        self.dropout_rate = dropout_rate
        self.vocab = vocab
        # self.device = torch.device('cpu')
        # default values
        self.encoder = None
        self.decoder = None
        self.h_projection = None
        self.c_projection = None
        self.att_projection = None
        self.combined_output_projection = None
        self.target_vocab_projection = None
        self.dropout = None
        # For sanity check only, not relevant to implementation
        self.gen_sanity_check = False
        self.counter = 0

        ### YOUR CODE HERE (~9 Lines)
        ### TODO - Initialize the following variables IN THIS ORDER:
        ###     self.post_embed_cnn (Conv1d layer with kernel size 2, input and output channels = embed_size,
        ###         padding = same to preserve output shape)
        ###     self.encoder (Bidirectional LSTM with bias)
        ###     self.decoder (LSTM Cell with bias)
        ###     self.h_projection (Linear Layer with no bias), called W_{h} in the PDF.
        ###     self.c_projection (Linear Layer with no bias), called W_{c} in the PDF.
        ###     self.att_projection (Linear Layer with no bias), called W_{attProj} in the PDF.
        ###     self.combined_output_projection (Linear Layer with no bias), called W_{u} in the PDF.
        ###     self.target_vocab_projection (Linear Layer with no bias), called W_{vocab} in the PDF.
        ###     self.dropout (Dropout Layer)
        ###
        ### Use the following docs to properly initialize these variables:
        ###     LSTM:
        ###         https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM
        ###     LSTM Cell:
        ###         https://pytorch.org/docs/stable/nn.html#torch.nn.LSTMCell
        ###     Linear Layer:
        ###         https://pytorch.org/docs/stable/nn.html#torch.nn.Linear
        ###     Dropout Layer:
        ###         https://pytorch.org/docs/stable/nn.html#torch.nn.Dropout
        ###     Conv1D Layer:
        ###         https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
        self.post_embed_cnn = nn.Conv1d(embed_size, embed_size, kernel_size=2, padding="same")
        self.encoder = nn.LSTM(input_size=embed_size, hidden_size=self.hidden_size, bidirectional=True)
        self.decoder = nn.LSTMCell(input_size=(embed_size + self.hidden_size), hidden_size=self.hidden_size)
        self.h_projection = nn.Linear(2 * self.hidden_size, self.hidden_size, bias=False)
        self.c_projection = nn.Linear(2 * self.hidden_size, self.hidden_size, bias=False)
        self.att_projection = nn.Linear(2 * self.hidden_size, self.hidden_size, bias=False)
        self.combined_output_projection = nn.Linear(3 * self.hidden_size, self.hidden_size, bias=False)
        # W_{vocab} maps the combined output (size h) to logits over the target vocabulary, with no bias
        self.target_vocab_projection = nn.Linear(self.hidden_size, len(self.vocab.tgt), bias=False)
        self.dropout = nn.Dropout(self.dropout_rate)
        ### END YOUR CODE
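
        # Shape reference (editorial note, derived from the layers above; b = batch size,
        # e = embed_size, h = hidden_size, V_tgt = len(vocab.tgt)):
        #   post_embed_cnn:              (b, e, src_len) -> (b, e, src_len)
        #   encoder (bi-LSTM):           (src_len, b, e) -> (src_len, b, 2h)
        #   decoder (LSTMCell):          input (b, e + h) -> hidden/cell (b, h)
        #   h_projection / c_projection: (b, 2h) -> (b, h)
        #   att_projection:              (b, src_len, 2h) -> (b, src_len, h)
        #   combined_output_projection:  (b, 3h) -> (b, h)
        #   target_vocab_projection:     (b, h) -> (b, V_tgt)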

    def forward(self, source: List[List[str]], target: List[List[str]]) -> torch.Tensor:
        """ Take a mini-batch of source and target sentences, compute the log-likelihood of
        target sentences under the language models learned by the NMT system.

        @param source (List[List[str]]): list of source sentence tokens
        @param target (List[List[str]]): list of target sentence tokens, wrapped by `<s>` and `</s>`

        @returns scores (Tensor): a variable/tensor of shape (b, ) representing the
                                    log-likelihood of generating the gold-standard target sentence for
                                    each example in the input batch. Here b = batch size.
        """
        # Compute sentence lengths
        source_lengths = [len(s) for s in source]

        # Convert list of lists into tensors
        source_padded = self.vocab.src.to_input_tensor(source, device=self.device)  # Tensor: (src_len, b)
        target_padded = self.vocab.tgt.to_input_tensor(target, device=self.device)  # Tensor: (tgt_len, b)

        ### Run the network forward:
        ### 1. Apply the encoder to `source_padded` by calling `self.encode()`
        ### 2. Generate sentence masks for `source_padded` by calling `self.generate_sent_masks()`
        ### 3. Apply the decoder to compute combined-output by calling `self.decode()`
        ### 4. Compute log probability distribution over the target vocabulary using the
        ###    combined_outputs returned by the `self.decode()` function.

        enc_hiddens, dec_init_state = self.encode(source_padded, source_lengths)
        enc_masks = self.generate_sent_masks(enc_hiddens, source_lengths)
        combined_outputs = self.decode(enc_hiddens, enc_masks, dec_init_state, target_padded)
        P = F.log_softmax(self.target_vocab_projection(combined_outputs), dim=-1)

        # Zero out probabilities at positions where the target text is just padding
        target_masks = (target_padded != self.vocab.tgt['<pad>']).float()

        # Compute log probability of generating true target words.
        # target_padded[1:] drops the leading `<s>`, so the gold word at step t is scored
        # against the distribution the decoder produces after consuming the first t tokens.
        target_gold_words_log_prob = torch.gather(P, index=target_padded[1:].unsqueeze(-1), dim=-1).squeeze(-1) * target_masks[1:]
        scores = target_gold_words_log_prob.sum(dim=0)
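        # Note (editorial): `scores` is the per-example sum of gold-word log-probabilities;
        # a training loop would typically turn it into a loss with something like
        # `loss = -model(src_sents, tgt_sents).sum()` (sketch, not part of this file).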
        return scores

    def encode(self, source_padded: torch.Tensor, source_lengths: List[int]) -> Tuple[
            torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """ Apply the encoder to source sentences to obtain encoder hidden states.
            Additionally, take the final states of the encoder and project them to obtain initial states for decoder.

        @param source_padded (Tensor): Tensor of padded source sentences with shape (src_len, b), where
                                        b = batch_size, src_len = maximum source sentence length. Note that
                                        these have already been sorted in order of longest to shortest sentence.
        @param source_lengths (List[int]): List of actual lengths for each of the source sentences in the batch
        @returns enc_hiddens (Tensor): Tensor of hidden units with shape (b, src_len, h*2), where
                                        b = batch size, src_len = maximum source sentence length, h = hidden size.
        @returns dec_init_state (tuple(Tensor, Tensor)): Tuple of tensors representing the decoder's initial
                                        hidden state and cell. Both tensors should have shape (b, h).
        """
        enc_hiddens, dec_init_state = None, None

        ### YOUR CODE HERE (~ 11 Lines)
        ### TODO:
        ###     1. Construct Tensor `X` of source sentences with shape (src_len, b, e) using the source model embeddings.
        ###         src_len = maximum source sentence length, b = batch size, e = embedding size. Note
        ###         that there is no initial hidden state or cell for the encoder.
        ###     2. Apply the post_embed_cnn layer. Before feeding X into the CNN, first use torch.permute to change the
        ###         shape of X to (b, e, src_len). After getting the output from the CNN, still stored in the X variable,
        ###         remember to use torch.permute again to revert X back to its original shape.
        ###     3. Compute `enc_hiddens`, `last_hidden`, `last_cell` by applying the encoder to `X`.
        ###         - Before you can apply the encoder, you need to apply the `pack_padded_sequence` function to X.
        ###         - After you apply the encoder, you need to apply the `pad_packed_sequence` function to enc_hiddens.
        ###         - Note that the shape of the tensor output returned by the encoder RNN is (src_len, b, h*2) and we want to
        ###           return a tensor of shape (b, src_len, h*2) as `enc_hiddens`, so you may need to do more permuting.
        ###         - Note on using pad_packed_sequence -> For batched inputs, you need to make sure that each of the
        ###           individual input examples has the same shape.
        ###     4. Compute `dec_init_state` = (init_decoder_hidden, init_decoder_cell):
        ###         - `init_decoder_hidden`:
        ###             `last_hidden` is a tensor shape (2, b, h). The first dimension corresponds to forwards and backwards.
        ###             Concatenate the forwards and backwards tensors to obtain a tensor shape (b, 2*h).
        ###             Apply the h_projection layer to this in order to compute init_decoder_hidden.
        ###             This is h_0^{dec} in the PDF. Here b = batch size, h = hidden size
        ###         - `init_decoder_cell`:
        ###             `last_cell` is a tensor shape (2, b, h). The first dimension corresponds to forwards and backwards.
        ###             Concatenate the forwards and backwards tensors to obtain a tensor shape (b, 2*h).
        ###             Apply the c_projection layer to this in order to compute init_decoder_cell.
        ###             This is c_0^{dec} in the PDF. Here b = batch size, h = hidden size
        ###
        ### See the following docs, as you may need to use some of the following functions in your implementation:
        ###     Pack the padded sequence X before passing to the encoder:
        ###         https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pack_padded_sequence.html
        ###     Pad the packed sequence, enc_hiddens, returned by the encoder:
        ###         https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pad_packed_sequence.html
        ###     Tensor Concatenation:
        ###         https://pytorch.org/docs/stable/generated/torch.cat.html
        ###     Tensor Permute:
        ###         https://pytorch.org/docs/stable/generated/torch.permute.html
        ###     Tensor Reshape (a possible alternative to permute):
        ###         https://pytorch.org/docs/stable/generated/torch.Tensor.reshape.html
        _, B = source_padded.shape
        X = self.model_embeddings.source(source_padded)  # (src_len, b, e)
        # Conv1d expects (b, e, src_len): permute in, convolve, then permute back to (src_len, b, e)
        X = self.post_embed_cnn(X.permute(1, 2, 0)).permute(2, 0, 1)
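        # pack_padded_sequence lets the LSTM skip padded timesteps; with its default
        # enforce_sorted=True it relies on the batch already being sorted by decreasing
        # length, which the docstring above guarantees.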
        enc_hiddens, (last_hidden, last_cell) = self.encoder(pack_padded_sequence(X, source_lengths))  # (src_len, b, h*2)
        enc_hiddens, _ = pad_packed_sequence(enc_hiddens)
        enc_hiddens = enc_hiddens.permute(1, 0, 2)  # (b, src_len, h*2)
        # Concatenate the forward and backward final states, then project 2h -> h
        init_decoder_hidden = self.h_projection(last_hidden.permute(1, 0, 2).reshape(B, -1))
        init_decoder_cell = self.c_projection(last_cell.permute(1, 0, 2).reshape(B, -1))
        dec_init_state = (init_decoder_hidden, init_decoder_cell)  # (b, h) and (b, h)
        ### END YOUR CODE

        return enc_hiddens, dec_init_state

    def decode(self, enc_hiddens: torch.Tensor, enc_masks: torch.Tensor,
               dec_init_state: Tuple[torch.Tensor, torch.Tensor], target_padded: torch.Tensor) -> torch.Tensor:
        """Compute combined output vectors for a batch.

        @param enc_hiddens (Tensor): Hidden states (b, src_len, h*2), where
                                     b = batch size, src_len = maximum source sentence length, h = hidden size.
        @param enc_masks (Tensor): Tensor of sentence masks (b, src_len), where
                                     b = batch size, src_len = maximum source sentence length.
        @param dec_init_state (tuple(Tensor, Tensor)): Initial state and cell for decoder
        @param target_padded (Tensor): Gold-standard padded target sentences (tgt_len, b), where
                                       tgt_len = maximum target sentence length, b = batch size.

        @returns combined_outputs (Tensor): combined output tensor (tgt_len, b, h), where
                                        tgt_len = maximum target sentence length, b = batch_size, h = hidden size
        """
        # Chop off the <END> token for max length sentences.
        target_padded = target_padded[:-1]

        # Initialize the decoder state (hidden and cell)
        dec_state = dec_init_state

        # Initialize previous combined output vector o_{t-1} as zero
        batch_size = enc_hiddens.size(0)
        o_prev = torch.zeros(batch_size, self.hidden_size, device=self.device)

        # Initialize a list we will use to collect the combined output o_t on each step
        combined_outputs = []

        ### YOUR CODE HERE (~9 Lines)
        ### TODO:
        ###     1. Apply the attention projection layer to `enc_hiddens` to obtain `enc_hiddens_proj`,
        ###         which should be shape (b, src_len, h),
        ###         where b = batch size, src_len = maximum source length, h = hidden size.
        ###         This is applying W_{attProj} to h^enc, as described in the PDF.
        ###     2. Construct tensor `Y` of target sentences with shape (tgt_len, b, e) using the target model embeddings.
        ###         where tgt_len = maximum target sentence length, b = batch size, e = embedding size.
        ###     3. Use the torch.split function to iterate over the time dimension of Y.
        ###         Within the loop, this will give you Y_t of shape (1, b, e) where b = batch size, e = embedding size.
        ###             - Squeeze Y_t into a tensor of dimension (b, e).
        ###             - Construct Ybar_t by concatenating Y_t with o_prev on their last dimension
        ###             - Use the step function to compute the Decoder's next (cell, state) values
        ###               as well as the new combined output o_t.
        ###             - Append o_t to combined_outputs
        ###             - Update o_prev to the new o_t.
        ###     4. Use torch.stack to convert combined_outputs from a list length tgt_len of
        ###         tensors shape (b, h), to a single tensor shape (tgt_len, b, h)
        ###         where tgt_len = maximum target sentence length, b = batch size, h = hidden size.
        ###
        ### Note:
        ###    - When using the squeeze() function make sure to specify the dimension you want to squeeze
        ###      over. Otherwise, you will remove the batch dimension accidentally, if batch_size = 1.
        ###
        ### You may find some of these functions useful:
        ###     Zeros Tensor:
        ###         https://pytorch.org/docs/stable/torch.html#torch.zeros
        ###     Tensor Splitting (iteration):
        ###         https://pytorch.org/docs/stable/torch.html#torch.split
        ###     Tensor Dimension Squeezing:
        ###         https://pytorch.org/docs/stable/torch.html#torch.squeeze
        ###     Tensor Concatenation:
        ###         https://pytorch.org/docs/stable/torch.html#torch.cat
        ###     Tensor Stacking:
        ###         https://pytorch.org/docs/stable/torch.html#torch.stack
        enc_hiddens_proj = self.att_projection(enc_hiddens)  # (b, src_len, h)
        Y = self.model_embeddings.target(target_padded)  # (tgt_len, b, e)
        for Y_t in torch.split(Y, 1):
            # Y_t: (1, b, e) -> (b, e); squeeze dim 0 explicitly so a batch of size 1 is not squeezed away
            Ybar_t = torch.cat([Y_t.squeeze(dim=0), o_prev], dim=1)  # (b, e + h)
            dec_state, o_t, _ = self.step(Ybar_t, dec_state, enc_hiddens, enc_hiddens_proj, enc_masks)
            combined_outputs.append(o_t)
            o_prev = o_t
        combined_outputs = torch.stack(combined_outputs)  # (tgt_len, b, h)
        ### END YOUR CODE

        return combined_outputs

    def step(self, Ybar_t: torch.Tensor,
             dec_state: Tuple[torch.Tensor, torch.Tensor],
             enc_hiddens: torch.Tensor,
             enc_hiddens_proj: torch.Tensor,
             enc_masks: torch.Tensor) -> Tuple[Tuple, torch.Tensor, torch.Tensor]:
        """ Compute one forward step of the LSTM decoder, including the attention computation.

        @param Ybar_t (Tensor): Concatenated Tensor of [Y_t o_prev], with shape (b, e + h). The input for the decoder,
                                where b = batch size, e = embedding size, h = hidden size.
        @param dec_state (tuple(Tensor, Tensor)): Tuple of tensors both with shape (b, h), where b = batch size, h = hidden size.
                First tensor is decoder's prev hidden state, second tensor is decoder's prev cell.
        @param enc_hiddens (Tensor): Encoder hidden states Tensor, with shape (b, src_len, h * 2), where b = batch size,
                                    src_len = maximum source length, h = hidden size.
        @param enc_hiddens_proj (Tensor): Encoder hidden states Tensor, projected from (h * 2) to h. Tensor is with shape (b, src_len, h),
                                    where b = batch size, src_len = maximum source length, h = hidden size.
        @param enc_masks (Tensor): Tensor of sentence masks shape (b, src_len),
                                    where b = batch size, src_len is maximum source length.

        @returns dec_state (tuple (Tensor, Tensor)): Tuple of tensors both shape (b, h), where b = batch size, h = hidden size.
                First tensor is decoder's new hidden state, second tensor is decoder's new cell.
        @returns combined_output (Tensor): Combined output Tensor at timestep t, shape (b, h), where b = batch size, h = hidden size.
        @returns e_t (Tensor): Tensor of shape (b, src_len). It is the attention scores distribution.
                                Note: You will not use this outside of this function.
                                      We are simply returning this value so that we can sanity check
                                      your implementation.
        """

        combined_output = None

        ### YOUR CODE HERE (~3 Lines)
        ### TODO:
        ###     1. Apply the decoder to `Ybar_t` and `dec_state` to obtain the new dec_state.
        ###     2. Split dec_state into its two parts (dec_hidden, dec_cell)
        ###     3. Compute the attention scores e_t, a Tensor shape (b, src_len).
        ###        Note: b = batch_size, src_len = maximum source length, h = hidden size.
        ###
        ###       Hints:
        ###         - dec_hidden is shape (b, h) and corresponds to h^dec_t in the PDF (batched)
        ###         - enc_hiddens_proj is shape (b, src_len, h) and corresponds to W_{attProj} h^enc (batched).
        ###         - Use batched matrix multiplication (torch.bmm) to compute e_t (be careful about the input/output shapes!)
        ###         - To get the tensors into the right shapes for bmm, you will need to do some squeezing and unsqueezing.
        ###         - When using the squeeze() function make sure to specify the dimension you want to squeeze
        ###             over. Otherwise, you will remove the batch dimension accidentally, if batch_size = 1.
        ###
        ### Use the following docs to implement this functionality:
        ###     Batch Multiplication:
        ###         https://pytorch.org/docs/stable/torch.html#torch.bmm
        ###     Tensor Unsqueeze:
        ###         https://pytorch.org/docs/stable/torch.html#torch.unsqueeze
        ###     Tensor Squeeze:
        ###         https://pytorch.org/docs/stable/torch.html#torch.squeeze

        dec_state = self.decoder(Ybar_t, dec_state)
        dec_hidden, dec_cell = dec_state
        e_t = torch.bmm(enc_hiddens_proj, dec_hidden.unsqueeze(-1)).squeeze(-1)  # (b, src_len, h) x (b, h, 1) -> (b, src_len)
        ### END YOUR CODE

        # Set e_t to -inf where enc_masks has 1
        if enc_masks is not None:
            e_t.data.masked_fill_(enc_masks.bool(), -float('inf'))

        ### YOUR CODE HERE (~6 Lines)
        ### TODO:
        ###     1. Apply softmax to e_t to yield alpha_t
        ###     2. Use batched matrix multiplication between alpha_t and enc_hiddens to obtain the
        ###         attention output vector, a_t.
        ###     Hints:
        ###           - alpha_t is shape (b, src_len)
        ###           - enc_hiddens is shape (b, src_len, 2h)
        ###           - a_t should be shape (b, 2h)
        ###           - You will need to do some squeezing and unsqueezing.
        ###     Note: b = batch size, src_len = maximum source length, h = hidden size.
        ###
        ###     3. Concatenate dec_hidden with a_t to compute tensor U_t
        ###     4. Apply the combined output projection layer to U_t to compute tensor V_t
        ###     5. Compute tensor O_t by first applying the Tanh function and then the dropout layer.
        ###
        ### Use the following docs to implement this functionality:
        ###     Softmax:
        ###         https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.softmax
        ###     Batch Multiplication:
        ###         https://pytorch.org/docs/stable/torch.html#torch.bmm
        ###     Tensor View:
        ###         https://pytorch.org/docs/stable/tensors.html#torch.Tensor.view
        ###     Tensor Concatenation:
        ###         https://pytorch.org/docs/stable/torch.html#torch.cat
        ###     Tanh:
        ###         https://pytorch.org/docs/stable/torch.html#torch.tanh

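        # alpha_t (b, src_len) holds the attention weights; a_t (b, 2h) is the context vector,
        # i.e. the attention-weighted sum of the encoder hidden states.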
        alpha_t = F.softmax(e_t, dim=1)
        a_t = torch.bmm(alpha_t.unsqueeze(1), enc_hiddens).squeeze(1)
        U_t = torch.cat([dec_hidden, a_t], dim=1)  # (b, 3h)
        V_t = self.combined_output_projection(U_t)
        O_t = self.dropout(torch.tanh(V_t))
        ### END YOUR CODE

        combined_output = O_t
        return dec_state, combined_output, e_t

    def generate_sent_masks(self, enc_hiddens: torch.Tensor, source_lengths: List[int]) -> torch.Tensor:
        """ Generate sentence masks for encoder hidden states.

        @param enc_hiddens (Tensor): encodings of shape (b, src_len, 2*h), where b = batch size,
                                     src_len = max source length, h = hidden size.
        @param source_lengths (List[int]): List of actual lengths for each of the sentences in the batch.

        @returns enc_masks (Tensor): Tensor of sentence masks of shape (b, src_len),
                                    where b = batch size, src_len = max source length.
        """
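        # Positions beyond a sentence's true length get mask value 1 (and are later filled
        # with -inf attention scores). For example, source_lengths = [3, 2] with src_len = 4
        # yields masks [[0, 0, 0, 1], [0, 0, 1, 1]].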
        enc_masks = torch.zeros(enc_hiddens.size(0), enc_hiddens.size(1), dtype=torch.float)
        for e_id, src_len in enumerate(source_lengths):
            enc_masks[e_id, src_len:] = 1
        return enc_masks.to(self.device)

    def beam_search(self, src_sent: List[str], beam_size: int = 5, max_decoding_time_step: int = 70) -> List[
            Hypothesis]:
        """ Given a single source sentence, perform beam search, yielding translations in the target language.
        @param src_sent (List[str]): a single source sentence (words)
        @param beam_size (int): beam size
        @param max_decoding_time_step (int): maximum number of time steps to unroll the decoding RNN
        @returns hypotheses (List[Hypothesis]): a list of hypotheses, each hypothesis has two fields:
                value: List[str]: the decoded target sentence, represented as a list of words
                score: float: the log-likelihood of the target sentence
        """
        src_sents_var = self.vocab.src.to_input_tensor([src_sent], self.device)

        src_encodings, dec_init_vec = self.encode(src_sents_var, [len(src_sent)])
        src_encodings_att_linear = self.att_projection(src_encodings)

        h_tm1 = dec_init_vec
        att_tm1 = torch.zeros(1, self.hidden_size, device=self.device)

        eos_id = self.vocab.tgt['</s>']

        hypotheses = [['<s>']]
        hyp_scores = torch.zeros(len(hypotheses), dtype=torch.float, device=self.device)
        completed_hypotheses = []

        t = 0
        while len(completed_hypotheses) < beam_size and t < max_decoding_time_step:
            t += 1
            hyp_num = len(hypotheses)

            exp_src_encodings = src_encodings.expand(hyp_num,
                                                     src_encodings.size(1),
                                                     src_encodings.size(2))

            exp_src_encodings_att_linear = src_encodings_att_linear.expand(hyp_num,
                                                                           src_encodings_att_linear.size(1),
                                                                           src_encodings_att_linear.size(2))

            y_tm1 = torch.tensor([self.vocab.tgt[hyp[-1]] for hyp in hypotheses], dtype=torch.long, device=self.device)
            y_t_embed = self.model_embeddings.target(y_tm1)

            x = torch.cat([y_t_embed, att_tm1], dim=-1)

            (h_t, cell_t), att_t, _ = self.step(x, h_tm1,
                                                exp_src_encodings, exp_src_encodings_att_linear, enc_masks=None)

            # log probabilities over target words
            log_p_t = F.log_softmax(self.target_vocab_projection(att_t), dim=-1)

            live_hyp_num = beam_size - len(completed_hypotheses)
            continuing_hyp_scores = (hyp_scores.unsqueeze(1).expand_as(log_p_t) + log_p_t).view(-1)
            top_cand_hyp_scores, top_cand_hyp_pos = torch.topk(continuing_hyp_scores, k=live_hyp_num)

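            # top_cand_hyp_pos indexes the flattened (hyp_num x |V_tgt|) score matrix:
            # floor-division recovers which hypothesis the candidate extends, and the
            # modulo recovers which target word it appends.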
            prev_hyp_ids = torch.div(top_cand_hyp_pos, len(self.vocab.tgt), rounding_mode='floor')
            hyp_word_ids = top_cand_hyp_pos % len(self.vocab.tgt)

            new_hypotheses = []
            live_hyp_ids = []
            new_hyp_scores = []

            for prev_hyp_id, hyp_word_id, cand_new_hyp_score in zip(prev_hyp_ids, hyp_word_ids, top_cand_hyp_scores):
                prev_hyp_id = prev_hyp_id.item()
                hyp_word_id = hyp_word_id.item()
                cand_new_hyp_score = cand_new_hyp_score.item()

                hyp_word = self.vocab.tgt.id2word[hyp_word_id]
                new_hyp_sent = hypotheses[prev_hyp_id] + [hyp_word]
                if hyp_word == '</s>':
                    completed_hypotheses.append(Hypothesis(value=new_hyp_sent[1:-1],
                                                           score=cand_new_hyp_score))
                else:
                    new_hypotheses.append(new_hyp_sent)
                    live_hyp_ids.append(prev_hyp_id)
                    new_hyp_scores.append(cand_new_hyp_score)

            if len(completed_hypotheses) == beam_size:
                break

            live_hyp_ids = torch.tensor(live_hyp_ids, dtype=torch.long, device=self.device)
            h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
            att_tm1 = att_t[live_hyp_ids]

            hypotheses = new_hypotheses
            hyp_scores = torch.tensor(new_hyp_scores, dtype=torch.float, device=self.device)

        if len(completed_hypotheses) == 0:
            completed_hypotheses.append(Hypothesis(value=hypotheses[0][1:],
                                                   score=hyp_scores[0].item()))

        completed_hypotheses.sort(key=lambda hyp: hyp.score, reverse=True)

        return completed_hypotheses

    @property
    def device(self) -> torch.device:
        """ Determine which device to place the Tensors upon, CPU or GPU.
        """
        return self.model_embeddings.source.weight.device

    @staticmethod
    def load(model_path: str):
        """ Load the model from a file.
        @param model_path (str): path to model
        """
        params = torch.load(model_path, map_location=lambda storage, loc: storage)
        args = params['args']
        model = NMT(vocab=params['vocab'], **args)
        model.load_state_dict(params['state_dict'])

        return model

    def save(self, path: str):
        """ Save the model to a file.
        @param path (str): path to the model
        """
        print('save model parameters to [%s]' % path, file=sys.stderr)

        params = {
            'args': dict(embed_size=self.model_embeddings.embed_size, hidden_size=self.hidden_size,
                         dropout_rate=self.dropout_rate),
            'vocab': self.vocab,
            'state_dict': self.state_dict()
        }

        torch.save(params, path)
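

# ---------------------------------------------------------------------------
# Editorial addition: a minimal, hypothetical smoke test (not part of the
# original assignment file). It assumes the CS224N a4 scaffold's vocab.py is
# importable and exposes VocabEntry.from_corpus(corpus, size, freq_cutoff)
# and Vocab(src_vocab, tgt_vocab); adjust the names if your copy differs.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from vocab import Vocab, VocabEntry  # assumed scaffold API

    # Tiny toy corpus: sources are plain token lists, targets are wrapped in <s> ... </s>.
    src_sents = [['buenos', 'dias', 'mundo'], ['hola', 'mundo']]
    tgt_sents = [['<s>', 'good', 'morning', 'world', '</s>'], ['<s>', 'hello', 'world', '</s>']]

    vocab = Vocab(VocabEntry.from_corpus(src_sents, 100, 1),
                  VocabEntry.from_corpus(tgt_sents, 100, 1))

    model = NMT(embed_size=16, hidden_size=32, vocab=vocab)

    # forward() expects the batch sorted by decreasing source length, with source/target pairs kept aligned.
    pairs = sorted(zip(src_sents, tgt_sents), key=lambda p: len(p[0]), reverse=True)
    scores = model([s for s, _ in pairs], [t for _, t in pairs])
    print('per-example target log-likelihoods:', scores)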