Please find the JAX implementation of this notebook here: https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/15/nmt_jax.ipynb
Neural machine translation using encoder-decoder RNN
We show how to implement NMT using an encoder-decoder.
Based on sec 9.7 of http://d2l.ai/chapter_recurrent-modern/seq2seq.html
Required functions for text preprocessing
For more details on these functions, see this colab.
Data
We use an English-French dataset. See this colab for details.
Encoder-decoder
Abstract base class
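A rough sketch of what this interface might look like (class and method names here are illustrative, not the notebook's exact API; we assume a functional style where parameters are passed in explicitly):

```python
class Encoder:
    """Maps an input sequence to a state for the decoder."""
    def __call__(self, params, X):
        raise NotImplementedError

class Decoder:
    """Generates the output sequence, conditioned on the encoder state."""
    def init_state(self, enc_outputs):
        raise NotImplementedError

    def __call__(self, params, X, state):
        raise NotImplementedError

class EncoderDecoder:
    """Composes an encoder and a decoder."""
    def __init__(self, encoder, decoder):
        self.encoder = encoder
        self.decoder = decoder

    def __call__(self, params, enc_X, dec_X):
        enc_outputs = self.encoder(params['encoder'], enc_X)
        dec_state = self.decoder.init_state(enc_outputs)
        return self.decoder(params['decoder'], dec_X, dec_state)
```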
Encoder
We use a two-layer GRU for the encoder; we set the context to be the final hidden state of the GRU. The input to the GRU is the word embedding of each token.
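Here is a minimal single-layer sketch of such an encoder in JAX (the notebook stacks two layers; the parameter names and the `gru_step` helper are hypothetical):

```python
import jax
import jax.numpy as jnp
from jax import lax

def gru_step(params, h, x):
    # One GRU update for a single time step (hypothetical parameter names).
    z = jax.nn.sigmoid(x @ params['Wxz'] + h @ params['Whz'] + params['bz'])
    r = jax.nn.sigmoid(x @ params['Wxr'] + h @ params['Whr'] + params['br'])
    h_tilde = jnp.tanh(x @ params['Wxh'] + (r * h) @ params['Whh'] + params['bh'])
    return z * h + (1 - z) * h_tilde

def encode(params, tokens):
    # tokens: (seq_len,) array of token ids for a single sentence.
    embeds = params['embedding'][tokens]      # (seq_len, embed_dim)
    h0 = jnp.zeros(params['Whz'].shape[0])    # initial hidden state
    def step(h, x):
        h_new = gru_step(params, h, x)
        return h_new, h_new
    h_final, all_states = lax.scan(step, h0, embeds)
    return all_states, h_final  # the context is the final hidden state
```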
Decoder
We use another GRU as the decoder. Its initial state is the final state of the encoder, so the two GRUs must use the same number of hidden units. In addition, we pass the context (i.e., the final state of the encoder) as input to every step of the decoder.
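As a sketch, one decoder step might look like the following, reusing the hypothetical `gru_step` from the encoder sketch (the decoder's input weight matrices are assumed to be sized for the concatenated embedding-plus-context input, and `Wout`/`bout` are a hypothetical output projection):

```python
def decode_step(params, h, token, context):
    # The context (the encoder's final state) is concatenated with the
    # input embedding at every time step.
    x = jnp.concatenate([params['embedding'][token], context])
    h = gru_step(params, h, x)                # reuses gru_step from above
    logits = h @ params['Wout'] + params['bout']
    return h, logits
```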
Loss function
We use cross entropy loss, but we must mask out target tokens that are just padding. To do this, we replace all entries beyond the valid length of each target sequence with the value 0.
We now use this to create a weight mask of 0s and 1s, where 0 corresponds to invalid locations. When we compute the cross entropy loss, we multiply by this weight mask, thus ignoring invalid locations.
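A sketch of how this masking might be implemented in JAX (the function names are hypothetical):

```python
import jax
import jax.numpy as jnp

def sequence_mask(max_len, valid_len):
    # (batch, max_len) matrix of 1s at valid positions, 0s at padding.
    return (jnp.arange(max_len)[None, :] < valid_len[:, None]).astype(jnp.float32)

def masked_ce_loss(logits, labels, valid_len):
    # logits: (batch, steps, vocab); labels: (batch, steps) of token ids.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    ce = -jnp.take_along_axis(log_probs, labels[..., None], axis=-1).squeeze(-1)
    weights = sequence_mask(labels.shape[1], valid_len)
    return (ce * weights).mean(axis=1)  # per-example loss, averaged over steps
```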
As an example, let us create a prediction tensor of all ones of size (3, 4, 10) and a target label tensor of all ones of size (3, 4). We set the valid lengths to (4, 2, 0). Thus the first loss should be twice the second, and the third loss should be 0.
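Using the hypothetical `masked_ce_loss` above:

```python
logits = jnp.ones((3, 4, 10))
labels = jnp.ones((3, 4), dtype=jnp.int32)
print(masked_ce_loss(logits, labels, jnp.array([4, 2, 0])))
# ~[2.3026, 1.1513, 0.0]: the first loss is twice the second, the third is 0
```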
Training
We use teacher forcing, where the inputs to the decoder are "bos" (beginning of sentence), followed by the ground truth target tokens from the previous step, as shown below.
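Concretely, the decoder input is the ground truth shifted right by one position, with "bos" prepended; a minimal sketch (the token ids here are made up for illustration):

```python
import jax.numpy as jnp

Y = jnp.array([[5, 6, 7, 3]])   # hypothetical target token ids; 3 = "eos"
bos_id = 2                      # hypothetical "bos" index
bos = jnp.full((Y.shape[0], 1), bos_id, dtype=Y.dtype)
dec_input = jnp.concatenate([bos, Y[:, :-1]], axis=1)
# dec_input == [[2, 5, 6, 7]]: "bos" followed by the shifted ground truth
```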
Prediction
We use greedy decoding, where the inputs to the decoder are "bos" (beginning of sentence), followed by the most likely target token from the previous step, as shown below. We keep decoding until the model generates "eos" (end of sentence).
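A minimal sketch of such a loop, built on the hypothetical `decode_step` above (`bos_id` and `eos_id` are assumed token indices):

```python
def greedy_decode(params, context, bos_id, eos_id, max_steps=10):
    h, token, output = context, bos_id, []
    for _ in range(max_steps):
        h, logits = decode_step(params, h, token, context)
        token = int(jnp.argmax(logits))   # most likely next token
        if token == eos_id:
            break
        output.append(token)
    return output
```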
Evaluation
In the MT community, the standard evaluation metric is known as BLEU (Bilingual Evaluation Understudy), which measures how many n-grams in the predicted target match the true label target.
For example, suppose the prediction is A,B,B,C,D and the target is A,B,C,D,E,F. There are five 1-grams in the prediction, of which 4 find a match in the target (the second "B" is a "false positive"), so the precision for 1-grams is $p_1 = 4/5$. Similarly, there are four 2-grams, of which 3 find a match (the bigram "BB" does not occur), so $p_2 = 3/4$. We continue in this way to compute up to $p_k$, where $k$ is the max n-gram length. (Since we are using words, not characters, we typically keep $k$ small, to avoid sparse counts.)
The BLEU score is then defined by $\mathrm{BLEU} = \exp\left(\min\left(0, 1 - \frac{\mathrm{len}_{\mathrm{label}}}{\mathrm{len}_{\mathrm{pred}}}\right)\right) \prod_{n=1}^{k} p_n^{1/2^n}$, where $\mathrm{len}_{\mathrm{label}}$ is the length of the target label sequence, and $\mathrm{len}_{\mathrm{pred}}$ is the length of the prediction.
Since predicting shorter sequences tends to give higher $p_n$ values, short sequences are penalized by the exponential factor. For example, suppose $k=2$ and the label sequence is A,B,C,D,E,F. If the predicted sequence is A,B,B,C,D, we have $p_1 = 4/5$ and $p_2 = 3/4$, and the penalty factor is $\exp(\min(0, 1 - 6/5)) \approx 0.82$. If the predicted sequence is A,B, we have $p_1 = p_2 = 1$, but the penalty factor is $\exp(\min(0, 1 - 6/2)) \approx 0.14$.
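Here is a short implementation along the lines of this formula (a sketch; `pred_tokens` and `label_tokens` are lists of tokens, `k` is the max n-gram length):

```python
import collections
import math

def bleu(pred_tokens, label_tokens, k):
    # Brevity penalty times the product of n-gram precisions p_n^(1/2^n).
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    score = math.exp(min(0, 1 - len_label / len_pred))
    for n in range(1, k + 1):
        matches = 0
        label_counts = collections.Counter(
            tuple(label_tokens[i:i + n]) for i in range(len_label - n + 1))
        for i in range(len_pred - n + 1):
            ngram = tuple(pred_tokens[i:i + n])
            if label_counts[ngram] > 0:   # clip matches by target counts
                matches += 1
                label_counts[ngram] -= 1
        score *= (matches / (len_pred - n + 1)) ** (0.5 ** n)
    return score

print(bleu(list('ABBCD'), list('ABCDEF'), k=2))  # ~0.68
```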