
Vanishing and Exploding Gradients in RNNs

Note: This is a text-only page introducing some important practical problems in training RNNs. We have introduced the basic idea on this page and, in the interest of time, have put the detailed discussion (which involves some detailed algebra) in an optional session. We highly recommend going through the optional session, which is the last session of this module.

The Gradient Propagation Problem in RNNs

Although RNNs are extremely versatile and can (theoretically) learn extremely complex functions, their biggest problem is that they are extremely hard to train, especially when the sequences get very long.

RNNs are designed to learn patterns in sequential data, i.e. patterns across ‘time’. They are also capable of learning what are called long-term dependencies. For example, in a machine translation task, we expect the network to learn the grammar of the languages, which often links words that are far apart in a sentence. This is accomplished through the recurrent layers of the network: each state carries the cumulative knowledge of the sequence seen so far.

Although this feature is what makes RNNs so powerful, it introduces a severe problem: as the sequences become longer, it becomes much harder to backpropagate the errors through the network. The gradients ‘die out’ by the time they reach the initial time steps during backpropagation.

RNNs use a slightly modified version of backpropagation to update the weights. In a standard neural network, the errors are propagated from the output layer to the input layer. However, in RNNs, errors are propagated not only from right to left but also through the time axis.

Refer to the figure above. Notice that the output $y_t$ depends not only on the input $x_t$, but also on the inputs $x_{t-1}$, $x_{t-2}$, and so on, all the way back to $x_1$. Thus, the loss at time $t$ depends on all the inputs up to $t$ and on the weights applied at every one of those steps. For example, if you change $W_2^F$, it will affect all the outputs $a_2^1, a_2^2, a_2^3, \ldots$, which will eventually affect the output $y_t$ (through a long feedforward chain).

This implies that the gradient of the loss at time $t$ needs to be propagated backwards through all the layers and all the time steps. To appreciate the complexity of this task, consider speech recognition: a typical spoken English sentence may have 30-40 words, so you have to backpropagate the gradients through as many as 40 time steps, as well as through the different layers.

This type of backpropagation is known as backpropagation through time or BPTT.
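To make BPTT concrete, here is a minimal sketch for the simplest possible case. It assumes a single-layer RNN with a scalar hidden state, a tanh activation, and a squared-error loss on the last output only; the update rule a_t = tanh(w_f·x_t + w_r·a_{t-1} + b) and the function names (`forward`, `loss`, `bptt_grad_w_r`) are illustrative choices, not part of this module. The backward loop walks from the last time step to the first, and at every step it multiplies the error by w_r before handing it further back in time.

```python
import math

# Assumed toy model (illustrative only): one layer, scalar hidden state,
# a_t = tanh(w_f * x_t + w_r * a_{t-1} + b), loss = 0.5 * (a_T - target)**2.

def forward(xs, w_f, w_r, b, a0=0.0):
    """Return the hidden states [a_0, a_1, ..., a_T] for inputs xs = [x_1, ..., x_T]."""
    states = [a0]
    for x in xs:
        states.append(math.tanh(w_f * x + w_r * states[-1] + b))
    return states

def loss(xs, w_f, w_r, b, target):
    """Squared error on the final hidden state only."""
    return 0.5 * (forward(xs, w_f, w_r, b)[-1] - target) ** 2

def bptt_grad_w_r(xs, w_f, w_r, b, target):
    """dL/dw_r, accumulated by walking backwards through every time step."""
    states = forward(xs, w_f, w_r, b)
    T = len(xs)
    delta = states[T] - target                # dL/da_T
    grad = 0.0
    for t in range(T, 0, -1):
        dz = delta * (1.0 - states[t] ** 2)   # through tanh: da_t/dz_t = 1 - a_t**2
        grad += dz * states[t - 1]            # w_r's direct contribution at step t
        delta = dz * w_r                      # error handed back to a_{t-1}
    return grad

# Compare against a central finite difference on w_r.
xs = [0.5, -1.0, 0.3, 0.8]
analytic = bptt_grad_w_r(xs, 0.7, 0.9, 0.1, target=0.2)
eps = 1e-6
numeric = (loss(xs, 0.7, 0.9 + eps, 0.1, 0.2) - loss(xs, 0.7, 0.9 - eps, 0.1, 0.2)) / (2 * eps)
print(analytic, numeric)  # the two values should agree closely
```

Because the same w_r is reused at every step, its gradient is a sum of contributions from all time steps, and the error reaching early steps has been multiplied by w_r (and by tanh derivatives) once per step on the way back.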

The exact backpropagation equations are covered in the optional session; here, let’s just build the high-level intuition. The feedforward equation for layer $l$ at time $t$ is:

$$a_l^t = f\left(W_l^F a_{l-1}^t + W_l^R a_l^{t-1} + b_l\right)$$
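The feedforward equation translates into a short loop over time. The sketch below is a minimal pure-Python version, assuming $f = \tanh$, a zero initial state $a_l^0$, and small hand-picked weights; the helper names and all numbers are hypothetical. Note that the same $W_l^F$, $W_l^R$ and $b_l$ are reused at every time step.

```python
import math

def matvec(W, v):
    """Multiply matrix W (a list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def rnn_layer_step(W_F, W_R, b, a_below, a_prev):
    """a_l^t = tanh(W_F a_{l-1}^t + W_R a_l^{t-1} + b)."""
    ff = matvec(W_F, a_below)   # from the layer below, same time step
    rec = matvec(W_R, a_prev)   # from this layer, previous time step
    return [math.tanh(u + r + bias) for u, r, bias in zip(ff, rec, b)]

def run_layer(W_F, W_R, b, inputs, hidden_size):
    """Apply the step to a whole sequence, reusing the same weights each time."""
    a_prev = [0.0] * hidden_size   # assumed initial state a_l^0 = 0
    outputs = []
    for a_below in inputs:
        a_prev = rnn_layer_step(W_F, W_R, b, a_below, a_prev)
        outputs.append(a_prev)
    return outputs

# Hypothetical weights: 2-dimensional inputs, 3 hidden units.
W_F = [[0.5, -0.2], [0.1, 0.4], [-0.3, 0.2]]
W_R = [[0.2, 0.0, 0.1], [0.0, 0.3, 0.0], [0.1, 0.0, 0.2]]
b = [0.0, 0.1, -0.1]
xs = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]

outs = run_layer(W_F, W_R, b, xs, hidden_size=3)
outs_changed = run_layer(W_F, W_R, b, [[-1.0, 0.0]] + xs[1:], hidden_size=3)
print(outs[-1])
print(outs_changed[-1])  # differs, although only the first input was changed
```

Changing only the first input changes the final output too, which is exactly the dependence of $y_t$ on $x_1$ described earlier; during training, that dependence is what the gradient has to travel back through.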

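Now, recall that in backpropagation we compute the gradient of subsequent layers with respect to the previous layers. In the time dimension, we compute the gradient of the output at time $t$ with respect to the output at time $t-1$, i.e. $\frac{\partial a_l^t}{\partial a_l^{t-1}}$. Differentiating the feedforward equation, this quantity is the derivative of the activation $f$ multiplied by the recurrent weight matrix $W_l^R$ (written $W_R$ below), so $\frac{\partial a_l^t}{\partial a_l^{t-1}}$ is some function of $W_R$: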

$$\frac{\partial a_l^t}{\partial a_l^{t-1}} = g(W_R)$$

Similarly, extending the gradient one time step further back, $\frac{\partial a_l^{t-1}}{\partial a_l^{t-2}}$ will also be a function of $W_R$, and so on. The problem is that these gradients are multiplied together during backpropagation,

$$\frac{\partial a_l^T}{\partial a_l^1} = \frac{\partial a_l^T}{\partial a_l^{T-1}} \cdot \frac{\partial a_l^{T-1}}{\partial a_l^{T-2}} \cdots \frac{\partial a_l^2}{\partial a_l^1},$$

and so the matrix $W_R$ effectively gets raised to higher and higher powers, $(W_R)^n$, as the error propagates backwards in time.
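The chain-rule product can be checked numerically. The sketch below is a minimal pure-Python example assuming $f = \tanh$ and a fixed stand-in vector `u` for the feedforward part $W_l^F a_{l-1}^t + b_l$ at each step; the helper names and numbers are hypothetical. It builds $\partial a^T / \partial a^1$ by multiplying the per-step Jacobians $\operatorname{diag}(1 - (a^t)^2)\, W_R$. Every factor contains $W_R$, which is why $W_R$ effectively appears $T-1$ times; with a linear activation the product would be exactly $(W_R)^{T-1}$.

```python
import math

def matmul(A, B):
    """Product of two matrices stored as lists of rows."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def step(W_R, u, a_prev):
    """a^t = tanh(u^t + W_R a^{t-1}); u^t stands in for W_F a_{l-1}^t + b."""
    return [math.tanh(u_i + sum(w * a for w, a in zip(row, a_prev)))
            for u_i, row in zip(u, W_R)]

def run(W_R, us, a1):
    """Roll the recurrence forward from a^1 through the inputs u^2 ... u^T."""
    a = a1
    for u in us:
        a = step(W_R, u, a)
    return a

def jacobian_product(W_R, us, a1):
    """d a^T / d a^1 as the chained product of the per-step Jacobians."""
    n = len(a1)
    J = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]  # identity
    a = a1
    for u in us:
        a = step(W_R, u, a)
        # d a^t / d a^{t-1} = diag(1 - (a^t)**2) @ W_R: every factor contains W_R
        J_t = [[(1.0 - a[i] ** 2) * W_R[i][j] for j in range(n)] for i in range(n)]
        J = matmul(J_t, J)  # the newest Jacobian multiplies from the left
    return J

W_R = [[0.5, -0.3], [0.2, 0.8]]
us = [[0.1, -0.2], [0.3, 0.0], [-0.1, 0.2]]   # u^2, u^3, u^4, so T = 4
a1 = [0.2, -0.4]
J = jacobian_product(W_R, us, a1)

# Finite-difference check of column 0: perturb the first component of a^1.
eps = 1e-6
plus = run(W_R, us, [a1[0] + eps, a1[1]])
minus = run(W_R, us, [a1[0] - eps, a1[1]])
print([(p - m) / (2 * eps) for p, m in zip(plus, minus)], [J[0][0], J[1][0]])
```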

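The longer the sequence, the higher the power, and this leads to exploding or vanishing gradients. For a single scalar weight the rule is simple: if its magnitude is greater than one, $(W_R)^n$ explodes to extremely large values; if it is less than one, $(W_R)^n$ shrinks towards zero. For a matrix, what matters is roughly the magnitude of its eigenvalues rather than its individual entries: eigenvalues larger than one in magnitude make the gradients explode, while eigenvalues smaller than one make them vanish.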
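A few lines of arithmetic show how fast this happens. The sketch below assumes the 40 time steps of the speech example and ignores the activation derivative (for tanh it is at most 1, so it can only shrink the product further). The first part uses a scalar recurrent weight; the second uses a hypothetical 2×2 matrix whose entries are all below one but whose largest eigenvalue is 1.2, so its powers still blow up.

```python
def matmul(A, B):
    """Product of two matrices stored as lists of rows."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def matpow(M, n):
    """M raised to the n-th power by repeated multiplication."""
    result = [[1.0 if i == j else 0.0 for j in range(len(M))] for i in range(len(M))]
    for _ in range(n):
        result = matmul(result, M)
    return result

steps = 40  # e.g. one time step per word in a 40-word sentence

# Scalar recurrent weight: the gradient factor is simply w_r ** steps.
vanishing = 0.9 ** steps   # about 0.0148: the signal has almost died out
exploding = 1.1 ** steps   # about 45.3: the signal has blown up

# Matrix case: every entry is below one, yet the largest eigenvalue is 1.2,
# so the entries of M**40 still grow large (each one equals 0.6 * 1.2**39).
M = [[0.6, 0.6], [0.6, 0.6]]
M40 = matpow(M, steps)

print(f"0.9**{steps} = {vanishing:.4f}, 1.1**{steps} = {exploding:.1f}")
print(f"largest entry of M**{steps}: {max(max(row) for row in M40):.1f}")
```

Weights only 10% away from one already change the gradient by a factor of about 45 in either direction over 40 steps, which is why long sequences are so much harder than short ones.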

This problem seriously impedes the learning mechanism of the neural network.

You can still use some workarounds to deal with exploding gradients. The most common one is to impose an upper limit on the gradient during training, a technique known as gradient clipping. By controlling the maximum value of a gradient, you can avoid the problem of exploding gradients.
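Gradient clipping comes in two common flavours, and both fit in a few lines. The sketch below is a minimal pure-Python illustration (the function names and thresholds are hypothetical): clipping by value caps each component separately, which can change the gradient's direction, while clipping by norm rescales the whole gradient vector so that its length never exceeds a threshold and its direction is preserved.

```python
import math

def clip_by_value(grads, limit):
    """Clamp every component to the range [-limit, limit]."""
    return [max(-limit, min(limit, g)) for g in grads]

def clip_by_norm(grads, max_norm):
    """Rescale the whole vector if its L2 norm exceeds max_norm (direction kept)."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    scale = max_norm / norm
    return [g * scale for g in grads]

exploded = [30.0, -40.0]               # L2 norm 50: an "exploded" gradient
print(clip_by_value(exploded, 5.0))    # [5.0, -5.0]: direction changed
print(clip_by_norm(exploded, 5.0))     # roughly [3.0, -4.0]: same direction, norm 5
```

Deep learning frameworks provide both out of the box; PyTorch, for example, has `torch.nn.utils.clip_grad_value_` and `torch.nn.utils.clip_grad_norm_`, which are applied to a model's parameters after the backward pass and before the optimizer step.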

But the problem of vanishing gradients is a more serious one. It is so pervasive in RNNs that it makes them nearly useless in many practical applications involving long sequences. One way around it is to use short sequences instead of long ones, but this is more of a compromise than a solution, since it restricts the applications where RNNs can be used.

To get rid of the vanishing gradient problem, researchers have been tinkering around with the RNN architecture for a long while. The most notable and popular modifications are the long short-term memory units (LSTMs) and the gated recurrent units (GRUs). You’ll study both these architectures in the next session.

To learn about backpropagation through time and the problem of vanishing and exploding gradients in detail, refer to the session ‘Additional Resources’, the last session of this module.

In the next section, you’ll go through the summary of the session.
