Please find torch implementation of this notebook here: https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/15/rnn_torch.ipynb
Recurrent neural networks
We show how to implement RNNs from scratch. Based on sec 8.5 of http://d2l.ai/chapter_recurrent-neural-networks/rnn-scratch.html.
Data
As data, we use the book "The Time Machine" by H G Wells, preprocessed using the code in this colab.
Model
We fit an unconditional RNN for language modeling (i.e., not vec2seq or seq2seq). Following the D2L notation, the model has the form
$$\mathbf{H}_t = \phi(\mathbf{X}_t \mathbf{W}_{xh} + \mathbf{H}_{t-1} \mathbf{W}_{hh} + \mathbf{b}_h), \qquad \mathbf{O}_t = \mathbf{H}_t \mathbf{W}_{hq} + \mathbf{b}_q,$$
where $\mathbf{X}_t \in \mathbb{R}^{N \times V}$ is the matrix of (one-hot) inputs (for batch size $N$ and vocabulary size $V$), $\mathbf{H}_t \in \mathbb{R}^{N \times H}$ is the matrix of hidden states (for $H$ hidden states), and $\mathbf{O}_t \in \mathbb{R}^{N \times Q}$ is the matrix of output logits (for $Q$ output labels, often $Q = V$).
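Below is a minimal JAX sketch of this from-scratch model (the names `init_rnn_params` and `rnn_step` are illustrative, not the notebook's exact code); it implements the two update equations above with a tanh nonlinearity.

```python
import jax
import jax.numpy as jnp

def init_rnn_params(key, vocab_size, num_hiddens, scale=0.01):
    """Randomly initialize W_xh, W_hh, b_h, W_hq, b_q."""
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "W_xh": scale * jax.random.normal(k1, (vocab_size, num_hiddens)),
        "W_hh": scale * jax.random.normal(k2, (num_hiddens, num_hiddens)),
        "b_h": jnp.zeros((num_hiddens,)),
        "W_hq": scale * jax.random.normal(k3, (num_hiddens, vocab_size)),
        "b_q": jnp.zeros((vocab_size,)),
    }

def rnn_step(params, H, X):
    """One time step: H_t = tanh(X_t W_xh + H_{t-1} W_hh + b_h); O_t = H_t W_hq + b_q."""
    H = jnp.tanh(X @ params["W_xh"] + H @ params["W_hh"] + params["b_h"])
    O = H @ params["W_hq"] + params["b_q"]
    return H, O
```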
Prediction (generation)
We pass in an initial prefix string, which is not generated. This is used to "warm up" the hidden state. Specifically, we update the hidden state given the observed prefix, but don't generate anything. After that, for each of the T steps, we compute the (1,V) output tensor, pick the argmax index, and append it to the output. Finally, we convert the indices to a readable token sequence of size (1,T). (Note that this is a greedy, deterministic procedure.)
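A hedged sketch of this greedy procedure, assuming the `rnn_step` and `params` from the previous sketch and integer token ids for the prefix:

```python
import jax
import jax.numpy as jnp

def predict_greedy(prefix_ids, num_preds, params, num_hiddens, vocab_size):
    """Warm up the state on the prefix, then greedily generate num_preds tokens."""
    H = jnp.zeros((1, num_hiddens))
    outputs = list(prefix_ids)
    # Warm-up: update the hidden state on the observed prefix, generating nothing.
    for idx in prefix_ids[:-1]:
        X = jax.nn.one_hot(jnp.array([idx]), vocab_size)         # (1, V)
        H, _ = rnn_step(params, H, X)
    # Generation: at each step, feed the previous token and take the argmax.
    for _ in range(num_preds):
        X = jax.nn.one_hot(jnp.array([outputs[-1]]), vocab_size)
        H, O = rnn_step(params, H, X)                            # O: (1, V) logits
        outputs.append(int(jnp.argmax(O, axis=-1)[0]))
    return outputs  # map back through the vocabulary to get readable tokens
```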
Training
To ensure the gradient doesn't blow up when doing backpropagation through many layers, we use gradient clipping, which corresponds to the update
$$\mathbf{g} \leftarrow \min\left(1, \frac{\theta}{\|\mathbf{g}\|}\right) \mathbf{g},$$
where $\theta$ is the scaling parameter and $\mathbf{g}$ is the gradient vector.
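A minimal sketch of this clipping rule applied to a JAX gradient pytree (the helper name `clip_grads` is an assumption):

```python
import jax
import jax.numpy as jnp

def clip_grads(grads, theta):
    """Rescale the whole gradient pytree so its global L2 norm is at most theta."""
    leaves = jax.tree_util.tree_leaves(grads)
    norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    scale = jnp.minimum(1.0, theta / norm)
    return jax.tree_util.tree_map(lambda g: scale * g, grads)
```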
The training step is fairly standard, except for the use of gradient clipping, and the issue of the hidden state. If the data iterator uses random ordering of the sequences, we need to initialize the hidden state for each minibatch. However, if the data iterator uses sequential ordering, we only initialize the hidden state at the very beginning of the process. In the latter case, the hidden state will depend on the value at the previous minibatch.
The state vector may be a tensor or a tuple, depending on what kind of RNN we are using. In addition, the parameter updater can be an optax optimizer or a simpler custom SGD optimizer.
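The following sketch shows one possible training step under these assumptions: a `loss_fn(params, state, X, Y)` that returns `(loss, new_state)`, an `init_state_fn` for fresh hidden states, the `clip_grads` helper above, and an optax `updater`. The names are illustrative, not the notebook's exact code.

```python
import jax
import optax

def train_step(params, state, opt_state, X, Y,
               use_random_iter, init_state_fn, loss_fn, updater, theta=1.0):
    if use_random_iter or state is None:
        # Random ordering: start each minibatch from a fresh hidden state.
        state = init_state_fn(batch_size=X.shape[0])
    else:
        # Sequential ordering: carry the state over, but make sure no gradient
        # flows back into the previous minibatch.
        state = jax.lax.stop_gradient(state)
    # loss_fn returns (loss, new_state); take gradients w.r.t. params only.
    (loss, state), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, state, X, Y)
    grads = clip_grads(grads, theta)       # gradient clipping from the sketch above
    updates, opt_state = updater.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, state, opt_state, loss
```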
The main training function is fairly standard. The loss function is the per-symbol cross-entropy, $-\log p(x_t \mid x_{t-1}, \ldots, x_1)$, where $p$ is the model prediction from the RNN. Since we compute the average loss across the $T$ time steps within a batch, we are computing $\frac{1}{T} \sum_{t=1}^{T} -\log p(x_t \mid x_{t-1}, \ldots, x_1)$. The exponential of this is the perplexity (ppl). We plot this metric during training, since it is independent of document length. In addition, we print the MAP sequence prediction following the prefix 'time traveller', to get a sense of what the model is doing.
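A small sketch of the per-symbol cross-entropy and the resulting perplexity metric, assuming flattened `logits` of shape `(T*N, V)` and matching integer `targets`:

```python
import jax
import jax.numpy as jnp

def perplexity(logits, targets):
    """exp of the mean per-token cross-entropy -log p(x_t | x_{<t})."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    ce = -jnp.take_along_axis(log_probs, targets[:, None], axis=-1).squeeze(-1)
    return jnp.exp(ce.mean())
```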
Creating a Flax module
We now show how to create an RNN as a Flax module, which is faster than our pure Python implementation.
While Flax has cells for more advanced recurrent models, it does not have a basic RNNCell. Therefore, we create an RNNCell similar to the cells defined in flax.linen.recurrent.
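A sketch of such a cell in the style of the cells in `flax.linen.recurrent` (the notebook's actual cell may differ in details); it maps `(carry, input)` to `(new_carry, output)`:

```python
import flax.linen as nn
import jax.numpy as jnp

class RNNCell(nn.Module):
    """Vanilla RNN cell: new_h = tanh(x W_ih + b + h W_hh)."""
    features: int

    @nn.compact
    def __call__(self, carry, x):
        h = carry
        new_h = jnp.tanh(nn.Dense(self.features, name="ih")(x) +
                         nn.Dense(self.features, use_bias=False, name="hh")(h))
        # Return the new carry and, as the per-step output, the hidden state itself.
        return new_h, new_h
```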
Now, we create an RNN module to call the RNNCell for each step.
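One way to do this is with `nn.scan`, as in the sketch below, which assumes the `RNNCell` sketched above and inputs of shape `(T, N, V)` scanned over the leading time axis:

```python
import flax.linen as nn

class RNN(nn.Module):
    """Apply RNNCell over the leading (time) axis of `inputs`."""
    features: int

    @nn.compact
    def __call__(self, carry, inputs):                    # inputs: (T, N, V)
        ScanRNNCell = nn.scan(RNNCell,
                              variable_broadcast="params",  # share params across steps
                              split_rngs={"params": False},
                              in_axes=0, out_axes=0)
        carry, outputs = ScanRNNCell(features=self.features)(carry, inputs)
        return carry, outputs                             # outputs: (T, N, features)
```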
Now we update the state with a random one-hot array of inputs.
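Illustrative usage under the same assumptions (the shapes and names below are made up for the example, and the `RNN` module is the sketch above):

```python
import jax
import jax.numpy as jnp

T, N, V, H = 5, 2, 28, 32                      # time steps, batch, vocab, hidden units
key_data, key_params = jax.random.split(jax.random.PRNGKey(0))
tokens = jax.random.randint(key_data, (T, N), 0, V)
X = jax.nn.one_hot(tokens, V)                  # (T, N, V) random one-hot inputs

model = RNN(features=H)
carry0 = jnp.zeros((N, H))                     # initial hidden state
variables = model.init(key_params, carry0, X)
carry, outputs = model.apply(variables, carry0, X)   # outputs: (T, N, H)
```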
Now we define our model. It consists of an RNN layer followed by a dense layer.
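A sketch of this composite model, assuming the `RNN` module above; the dense layer maps hidden states to vocabulary logits:

```python
import flax.linen as nn

class RNNLM(nn.Module):
    """RNN layer followed by a dense layer that produces vocabulary logits."""
    num_hiddens: int
    vocab_size: int

    @nn.compact
    def __call__(self, carry, inputs):                    # inputs: (T, N, V) one-hot
        carry, hiddens = RNN(features=self.num_hiddens)(carry, inputs)
        logits = nn.Dense(self.vocab_size)(hiddens)       # (T, N, V) output logits
        return carry, logits
```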
Test the untrained model.
Train it. The results are similar to the 'from scratch' implementation, but much faster.