Please find the JAX implementation of this notebook here: https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/15/rnn_jax.ipynb
Recurrent neural networks
We show how to implement RNNs from scratch. Based on sec 8.5 of http://d2l.ai/chapter_recurrent-neural-networks/rnn-scratch.html.
Data
As data, we use the book "The Time Machine" by H. G. Wells, preprocessed using the code in this colab.
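For reference, here is a minimal sketch of this data-loading step, assuming the `d2l` package's `load_data_time_machine` helper (the batch size and sequence length below are illustrative, not necessarily the notebook's settings):

```python
from d2l import torch as d2l

batch_size, num_steps = 32, 35
# train_iter yields (X, Y) minibatches of token indices; vocab maps tokens <-> indices
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
print(len(vocab))  # size of the character-level vocabulary
```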
Model
We fit an unconditional RNN, for language modeling (i.e. not vec2seq or seq2seq). Following the D2L notation, the model has the form
$$H_t = \phi(X_t W_{xh} + H_{t-1} W_{hh} + b_h), \qquad O_t = H_t W_{hq} + b_q$$
where $X_t$ is the $(N,V)$ matrix of (one-hot) inputs (for batch size $N$ and vocabulary size $V$), $H_t$ is the $(N,H)$ matrix of hidden states (for $H$ hidden states), and $O_t$ is the $(N,C)$ matrix of output logits (for $C$ output labels, often $C=V$).
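To make the recurrence concrete, here is a rough sketch of the forward pass implied by these equations (not the notebook's exact code); the parameter names and the tanh nonlinearity are assumptions:

```python
import torch

def rnn_forward(inputs, state, params):
    """Sketch of the recurrence above.
    inputs: (T, N, V) one-hot tensor; state: (N, H) hidden state."""
    W_xh, W_hh, b_h, W_hq, b_q = params
    H = state
    outputs = []
    for X_t in inputs:                                  # X_t has shape (N, V)
        H = torch.tanh(X_t @ W_xh + H @ W_hh + b_h)     # new hidden state (N, H)
        O_t = H @ W_hq + b_q                            # output logits (N, C)
        outputs.append(O_t)
    return torch.cat(outputs, dim=0), H                 # logits (T*N, C), final state
```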
Prediction (generation)
We pass in an initial prefix string that is not generated; it is used to "warm up" the hidden state. Specifically, we update the hidden state given the observed prefix, but don't generate anything. After that, for each of the T steps, we compute the (1,V) output tensor, pick the argmax index, and append it to the output. Finally, we convert the indices into a readable token sequence of size (1,T). (Note that this is a greedy, deterministic procedure.)
(D2L calls this predict_ch8, since it occurs in their chapter 8.)
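A sketch of this greedy decoding loop is given below; it assumes the D2L-style interfaces `net(inputs, state)` and `net.begin_state(...)`, and a `vocab` object that maps tokens to indices (`vocab[token]`) and back (`vocab.idx_to_token`):

```python
import torch

def predict_greedy(prefix, num_preds, net, vocab, device):
    """Warm up the state on `prefix`, then greedily generate `num_preds` tokens."""
    state = net.begin_state(batch_size=1, device=device)
    outputs = [vocab[prefix[0]]]
    get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape((1, 1))
    for y in prefix[1:]:                 # warm-up: update the state, generate nothing
        _, state = net(get_input(), state)
        outputs.append(vocab[y])
    for _ in range(num_preds):           # greedy generation: argmax at each step
        y, state = net(get_input(), state)
        outputs.append(int(y.argmax(dim=1).reshape(1)))
    return ''.join([vocab.idx_to_token[i] for i in outputs])
```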
Training
To ensure the gradient doesn't blow up when doing backpropagation through many layers, we use gradient clipping, which corresponds to the update
$$\mathbf{g} \leftarrow \min\left(1, \frac{\theta}{\|\mathbf{g}\|}\right) \mathbf{g},$$
where $\theta$ is the scaling parameter, and $\mathbf{g}$ is the gradient vector.
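A sketch of this clipping step, written for a PyTorch module (a from-scratch model would expose its parameter list instead):

```python
import torch

def grad_clipping(net, theta):
    """Rescale the gradients so their global L2 norm is at most theta."""
    if isinstance(net, torch.nn.Module):
        params = [p for p in net.parameters() if p.requires_grad]
    else:
        params = net.params
    norm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in params))
    if norm > theta:
        for p in params:
            p.grad[:] *= theta / norm
```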
The training step is fairly standard, except for the use of gradient clipping and the handling of the hidden state. If the data iterator uses random ordering of the sequences, we need to initialize the hidden state for each minibatch. However, if the data iterator uses sequential ordering, we only initialize the hidden state at the very beginning of the process. In the latter case, the hidden state will depend on its value at the end of the previous minibatch, so we detach the state vector to prevent gradients from flowing across minibatch boundaries.
The state vector may be a tensor or a tuple, depending on what kind of RNN we are using. In addition, the parameter updater can be a built-in optimizer, or the simpler D2L sgd optimizer.
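Putting these pieces together, here is a sketch of one training epoch. The names (`train_epoch`, `begin_state`) are ours, the updater is assumed to be a built-in PyTorch optimizer, and `grad_clipping` is the helper sketched above:

```python
def train_epoch(net, train_iter, loss, updater, device, use_random_iter):
    """One epoch over the data; returns (summed token loss, number of tokens)."""
    state, total_loss, total_tokens = None, 0.0, 0
    for X, Y in train_iter:
        if state is None or use_random_iter:
            # fresh state for the first minibatch, or for every minibatch
            # when the iterator samples subsequences in random order
            state = net.begin_state(batch_size=X.shape[0], device=device)
        elif isinstance(state, tuple):
            # e.g. an LSTM state is a tuple of tensors
            state = tuple(s.detach() for s in state)
        else:
            # detach so gradients do not flow across minibatch boundaries
            state = state.detach()
        X, y = X.to(device), Y.T.reshape(-1).to(device)
        y_hat, state = net(X, state)
        l = loss(y_hat, y.long()).mean()
        updater.zero_grad()
        l.backward()
        grad_clipping(net, 1)            # clip before the parameter update
        updater.step()
        total_loss += l.item() * y.numel()
        total_tokens += y.numel()
    return total_loss, total_tokens
```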
The main training function is fairly standard. The loss function is the per-symbol cross-entropy, $-\log p(x_t \mid x_{t-1}, \ldots, x_1)$, where $p$ is the model prediction from the RNN. Since we compute the average loss across time steps within a batch, we are computing $\frac{1}{T} \sum_{t=1}^T -\log p(x_t \mid x_{t-1}, \ldots, x_1)$. The exponential of this is the perplexity (ppl). We plot this metric during training, since it is independent of document length. In addition, we print the MAP sequence prediction following the prefix 'time traveller', to get a sense of what the model is doing.
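For concreteness, a sketch of the outer loop, using the hypothetical `train_epoch` and `predict_greedy` helpers from above; exponentiating the average per-token loss gives the perplexity:

```python
import math
import torch

def train(net, train_iter, vocab, lr, num_epochs, device, use_random_iter=False):
    """Sketch of the main loop: minimize per-token cross-entropy, report perplexity."""
    loss = torch.nn.CrossEntropyLoss()
    updater = torch.optim.SGD(net.parameters(), lr)
    for epoch in range(num_epochs):
        total_loss, total_tokens = train_epoch(
            net, train_iter, loss, updater, device, use_random_iter)
        ppl = math.exp(total_loss / total_tokens)    # perplexity = exp(average loss)
        print(f'epoch {epoch + 1}, perplexity {ppl:.1f}')
        # greedy continuation of a fixed prefix, to eyeball what the model has learned
        print(predict_greedy('time traveller', 50, net, vocab, device))
```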
Creating a PyTorch module
We now show how to create an RNN as a PyTorch module, which is faster than our pure Python implementation.
First we create a single-hidden-layer RNN, together with a tensor to represent its state.
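For example (the sizes below are illustrative; the real notebook would use `len(vocab)` as the input size):

```python
import torch
from torch import nn

vocab_size, num_hiddens, batch_size = 28, 256, 32

rnn_layer = nn.RNN(vocab_size, num_hiddens)         # a single hidden layer by default
state = torch.zeros((1, batch_size, num_hiddens))   # shape (num_layers, N, H)
```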
Now we update the state with a random one-hot tensor of inputs.
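Continuing the sketch above, note that the outputs of `nn.RNN` are the per-step hidden states; the output projection to logits is added later in the module:

```python
num_steps = 35
# random one-hot inputs of shape (T, N, V)
idx = torch.randint(0, vocab_size, (num_steps, batch_size))
X = torch.nn.functional.one_hot(idx, vocab_size).to(torch.float32)

Y, state_new = rnn_layer(X, state)
print(Y.shape, state_new.shape)   # (num_steps, N, H) and (1, N, H)
```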
Now we make an RNN module.
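A minimal sketch of such a module, loosely following D2L's RNNModel (unidirectional RNN only, with a linear output layer on top):

```python
class RNNModel(nn.Module):
    """Wrap an nn.RNN layer together with an output projection."""
    def __init__(self, rnn_layer, vocab_size):
        super().__init__()
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.num_hiddens = self.rnn.hidden_size
        self.linear = nn.Linear(self.num_hiddens, vocab_size)

    def forward(self, inputs, state):
        # inputs: (N, T) integer tokens -> (T, N, V) one-hot floats
        X = nn.functional.one_hot(inputs.T.long(), self.vocab_size).to(torch.float32)
        Y, state = self.rnn(X, state)
        # collapse time and batch so the logits have shape (T*N, V)
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output, state

    def begin_state(self, batch_size, device):
        return torch.zeros((self.rnn.num_layers, batch_size, self.num_hiddens),
                           device=device)
```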
Test the untrained model.
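For example, instantiating the module and sampling from it before training (the output should be gibberish at this point); `predict_greedy` is the sketch from earlier:

```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
rnn_layer = nn.RNN(len(vocab), num_hiddens)   # input size must match the vocabulary
net = RNNModel(rnn_layer, vocab_size=len(vocab)).to(device)
print(predict_greedy('time traveller', 10, net, vocab, device))
```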
Train it. The results are similar to the 'from scratch' implementation, but much faster.