LSTM (Long Short Term Memory)
Now that we've understood the motivation behind the Recurrent Neural Network (RNN) and seen its implementation, let's turn our attention towards its more powerful variants.
Recall that RNNs are networks with loops in them, allowing them to store information about the previous state and potentially leverage it to better reason about the current state. One popular diagram that we might come across for RNNs is the following:
In the diagram above, a chunk of neural network (our RNN layer) takes some input $x_t$ and outputs a value $h_t$. The loop denotes that the network repeats this process for every element in our input sequence. In other words, when given a sentence of 4 words, the network (the RNN cell) will unroll itself into 4 copies, one copy for each word.
The main issue with these vanilla RNNs is that they tend to suffer from the vanishing gradient problem. Training an RNN is similar to training a traditional neural network: we also use the backpropagation algorithm, but with a little twist. Because the parameters are shared by all time steps in the network, the gradient at each output depends not only on the calculations of the current time step, but also on those of the previous time steps. For example, in order to calculate the gradient at $t=4$ we would need to backpropagate 3 steps and sum up the gradients. This is called Backpropagation Through Time (BPTT). We can then imagine that, when computing the gradient, if we multiply a small number with another small number, and another, and another, the value dramatically decays towards 0, and the weights will no longer be updated once the gradients reach 0. To mitigate this issue, other variants of RNNs were developed, and this notebook will look at one of them called LSTM (Long Short Term Memory).
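To get a feel for why multiplying many small numbers together causes this decay, here is a toy numpy sketch; the recurrent weight of 0.5 and the random pre-activations are made up purely for illustration, not taken from any trained network:

```python
import numpy as np

# Toy illustration of backpropagation through time: the gradient that
# reaches the earliest time step is (roughly) a product of per-step terms,
# each of which is often smaller than 1.
np.random.seed(1)
n_steps = 20
grad = 1.0
for step in range(n_steps):
    h = np.random.uniform(-2, 2)           # some made-up pre-activation value
    local_grad = 1 - np.tanh(h) ** 2       # derivative of tanh, always <= 1
    recurrent_weight = 0.5                 # a made-up recurrent weight < 1
    grad *= local_grad * recurrent_weight  # chain rule through one time step
    print(f"step {step:2d}, gradient magnitude: {grad:.2e}")
```

After only a handful of steps the gradient magnitude is already tiny, which is exactly the behavior the paragraph above describes.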
LSTM Step by Step Walk Through
Note that a large portion of the content for this section is based on Blog: Understanding LSTM Networks
Given $x_t$, the input at time step $t$, and $h_{t-1}$, the hidden state at time step $t-1$, the computation happening in a vanilla RNN cell is as follows:

$$h_t = f(W x_t + U h_{t-1} + b)$$

Here, $f$ is usually a nonlinearity function such as tanh. LSTMs also have this chain-like structure, but the repeating module has a different structure. Instead of having a single set of weights $W$ and $U$ connecting the input and hidden state respectively, there are four sets of weights, interacting in a very special way.
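As a quick refresher, a single vanilla RNN step can be sketched in a few lines of numpy; the dimensions and random weights below are arbitrary and only meant to show the shape of the computation in the formula above:

```python
import numpy as np

input_size, hidden_size = 3, 4

# randomly initialized parameters, purely for illustration
W = np.random.randn(hidden_size, input_size)   # input to hidden
U = np.random.randn(hidden_size, hidden_size)  # hidden to hidden
b = np.zeros(hidden_size)

def rnn_step(x_t, h_prev):
    """One step of a vanilla RNN: h_t = tanh(W x_t + U h_{t-1} + b)."""
    return np.tanh(W @ x_t + U @ h_prev + b)

x_t = np.random.randn(input_size)
h_prev = np.zeros(hidden_size)
h_t = rnn_step(x_t, h_prev)
print(h_t.shape)  # (4,)
```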
LSTMs are designed to avoid the long-term dependency problem, and the core idea is the cell state, the horizontal line running through the top of the diagram.
The cell state is kind of like a conveyor belt. It runs straight down the entire chain, with only some linear interactions, making it easy for information to flow along it unchanged. LSTMs have the ability to remove or add information to the cell state, carefully regulated by structures called gates. Namely, the forget gate, input gate and output gate.
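As a tiny illustration of what a gate does (with made-up numbers), a sigmoid output close to 1 lets a value pass through almost unchanged, while an output close to 0 suppresses it:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

cell_state = np.array([0.8, -1.2, 2.0])
gate = sigmoid(np.array([4.0, -4.0, 0.0]))  # roughly [0.98, 0.02, 0.5]

# element-wise product: first value mostly kept, second mostly erased
print(gate * cell_state)
```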
The first step for an LSTM cell is to decide what information we're going to throw away from the cell state. This is determined by a sigmoid layer, the "forget gate". The forget gate looks at $h_{t-1}$ and $x_t$, and outputs a number between 0 and 1 for each number in the cell state $C_{t-1}$. A 1 represents "completely keep this" while a 0 represents "completely get rid of this".

$$f_t = \sigma(W_f x_t + U_f h_{t-1} + b_f)$$

Note that $W_f$ and $U_f$ denote the weights connecting the input and the hidden state respectively, and the subscript $f$ denotes that these are the set of weights for the forget gate.
For example, suppose we are building a language model that's trying to predict the next word based on all the previous ones. In such a problem, the cell state might include the gender of the present subject, so that the correct pronouns can be used. When we see a new subject, we want to forget the gender of the old subject.
The second step is to determine what new information to store in the cell state. This step consists of two parts: first, a sigmoid layer known as the "input gate" decides which values we'll update. Second, a tanh layer creates a vector of new candidate values, $\tilde{C}_t$, that could be added to the state.

$$i_t = \sigma(W_i x_t + U_i h_{t-1} + b_i)$$

$$\tilde{C}_t = \tanh(W_C x_t + U_C h_{t-1} + b_C)$$
In the example of our language model, we want to add the gender of the new subject to the cell state to replace the old one we're forgetting.
It's now time to update the old cell state, $C_{t-1}$, into the new cell state $C_t$. The previous steps already decided what to do, we just need to actually do it. We multiply the old state by $f_t$, forgetting the things we decided to forget earlier. Then we add $i_t * \tilde{C}_t$, which is the new candidate values scaled by how much we've decided to update each value.

$$C_t = f_t * C_{t-1} + i_t * \tilde{C}_t$$

Looking at this formula more carefully, we can see that the information carried by the previous cell state $C_{t-1}$ will not be lost as long as its weight, i.e. the forget gate $f_t$, is on (close to 1), making LSTMs better at learning long-term dependencies compared to vanilla RNNs.
In the case of the language model, this is where we'd actually drop the information about the old subject's gender and add the new information, as we decided in the previous steps.
Finally, we need to decide what we're going to output. This output will be a filtered version of our cell state. First, we run a sigmoid layer which decides what parts of the cell state we're going to output; this is essentially our output gate. Then, we put the cell state through tanh (to push the values to be between -1 and 1) and multiply it by the output of the sigmoid gate, so that we only output the parts we decided to.

$$o_t = \sigma(W_o x_t + U_o h_{t-1} + b_o)$$

$$h_t = o_t * \tanh(C_t)$$
For the language model example, since it just saw a subject, it might want to output information relevant to a verb, in case that's what comes next. For instance, it might output whether the subject is singular or plural, so that we know what form the verb should take if a verb does follow.
Implementation
A lot of the code is similar to that of the vanilla RNN implementation.
The LSTM formulas are listed again for quick reference; note that for the implementation, we've excluded the bias terms to keep things simpler.

$$f_t = \sigma(W_f x_t + U_f h_{t-1})$$

$$i_t = \sigma(W_i x_t + U_i h_{t-1})$$

$$\tilde{C}_t = \tanh(W_C x_t + U_C h_{t-1})$$

$$C_t = f_t * C_{t-1} + i_t * \tilde{C}_t$$

$$o_t = \sigma(W_o x_t + U_o h_{t-1})$$

$$h_t = o_t * \tanh(C_t)$$
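To make these formulas concrete before jumping into the full model, here is a minimal numpy sketch of a single LSTM step that follows the equations above (bias terms excluded). The dimensions, random weights and function names are made up for illustration and are not the TensorFlow implementation used in the rest of the notebook:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

input_size, hidden_size = 3, 5
rng = np.random.RandomState(0)

# one pair of weights (W: input, U: hidden) per gate / candidate, no bias
W_f, U_f = rng.randn(hidden_size, input_size), rng.randn(hidden_size, hidden_size)
W_i, U_i = rng.randn(hidden_size, input_size), rng.randn(hidden_size, hidden_size)
W_c, U_c = rng.randn(hidden_size, input_size), rng.randn(hidden_size, hidden_size)
W_o, U_o = rng.randn(hidden_size, input_size), rng.randn(hidden_size, hidden_size)

def lstm_step(x_t, h_prev, c_prev):
    """One LSTM step following the formulas listed above (no bias terms)."""
    f_t = sigmoid(W_f @ x_t + U_f @ h_prev)      # forget gate
    i_t = sigmoid(W_i @ x_t + U_i @ h_prev)      # input gate
    c_tilde = np.tanh(W_c @ x_t + U_c @ h_prev)  # candidate cell state
    c_t = f_t * c_prev + i_t * c_tilde           # new cell state
    o_t = sigmoid(W_o @ x_t + U_o @ h_prev)      # output gate
    h_t = o_t * np.tanh(c_t)                     # new hidden state
    return h_t, c_t

x_t = rng.randn(input_size)
h_prev, c_prev = np.zeros(hidden_size), np.zeros(hidden_size)
h_t, c_t = lstm_step(x_t, h_prev, c_prev)
print(h_t.shape, c_t.shape)  # (5,) (5,)
```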
Conclusion
Given the more complex structure, it makes sense that LSTMs take longer to train compared to vanilla RNNs. But, thankfully, they do give better performance on the test set.
Most of the exciting results achieved today by RNN-like networks are in fact achieved by LSTMs, because of their capability to deal with long-term dependencies. The long-term dependency problem is that, as we unroll the network over more time steps, the gradient decays quickly during backpropagation, so training an RNN that is unfolded over a long time span becomes practically impossible. LSTMs avoid this decaying gradient problem by giving us a super highway (the cell state) through time; this highway allows the gradient to flow backward in time relatively freely, making them less susceptible to vanishing gradients.
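One way to see this "highway" argument in symbols: if, as a simplification, we treat the gate values and the candidate as constants with respect to the earlier cell state, then differentiating the cell state update rule gives

$$\frac{\partial C_t}{\partial C_{t-1}} \approx f_t \quad\Rightarrow\quad \frac{\partial C_t}{\partial C_{t-k}} \approx \prod_{j=t-k+1}^{t} f_j$$

so as long as the forget gates stay close to 1, this product does not shrink towards 0 the way the repeated products of weights and tanh derivatives do in a vanilla RNN.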