RNNs a walkthrough

9 min read · 22 Aug 2023

Introduction#

Looking back from a point where GPT models have sent waves across the industry, recurrent networks are one of the most fundamental building blocks laid down in the progress of natural language processing and deep learning. This blog aims to give a functional understanding of RNNs. Personally, I had a difficult time wrapping my head around them when I learnt them the first time. I believe this blog will reduce the friction for people who have just started their journey.

This blog is the first in a series on recurrent neural networks. LSTMs are an extension of RNNs, so we will first go through a brief overview of RNNs and understand their working and significance.

Why do we have recurrent networks?#

Before jumping into the space of recurrent networks, let us try to understand why we need them in the first place.

Reason 1#

For instance, imagine you are a mobile phone’s keyboard that suggests the next word to the user. What would you try to learn in this case?

  • Is the person frequently using a group of words?
    • This seems interesting
    • Woah that’s awesome
  • How is the person talking to their friends or colleagues?
    • Hey, I am free tonight, let’s catch up for games
    • Hi, let us set up a meeting tomorrow morning
  • Is the sequence of some words same most of the time?

Notice how each word before the suggestion sets up the context. You could suggest these words:

  • Hey, I am free tonight, let’s catch up for [games | walk | drinks]
  • Hi, let us set up a meeting tomorrow [morning | afternoon | evening]

In order to give these word suggestions, you will have to remember these sequences, or to be specific, the meaning they convey.

Traditional neural networks are forgetful in nature. Consider passing a sentence - “Hello are you free tomorrow evening” - word by word; by the time you pass the word “tomorrow”, the network might have forgotten what words came prior to it. The network might instead learn which words occur most frequently with the word “tomorrow”. Consider these sentences.

  • “Tomorrow I have a meeting”
  • “Tomorrow I have a holiday”
  • “Tomorrow I won’t be available”

:bulb: Notice how using the word “Tomorrow” at the beginning of a sentence versus in the middle of it changes the next word prediction.

Say this is a general pattern in a person’s usage of the word “Tomorrow”. If we use a traditional neural network, we might suggest the word “I” for the sentence “Hi, let us set up a meeting tomorrow ____”

Reason 2#

We saw that we were not able to retain words using traditional neural networks. One might argue: what happens if we show all the words of a sentence to the neural network at the same time?

Say you want to classify text into 3 sentiments - positive, negative and neutral - and you are given a traditional neural network. Consider these 2 sentences.

  • “It is a good product” → positive
  • “It is a good product but it stopped working in a few days” → negative

We can clearly see that each sentence has a different number of words, which changes the dimensionality of the vector representation of each sentence. If we truncate a sentence we lose its meaning; if we pad a sentence with dummy vectors we increase the complexity of the network by increasing dimensionality.
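To make the dimensionality problem concrete, here is a small sketch. The vocabulary and the one-hot encoding scheme are illustrative assumptions, not from any particular library:

```python
# A tiny made-up vocabulary covering both example sentences.
vocab = ["it", "is", "a", "good", "product", "but", "stopped",
         "working", "in", "few", "days"]

def one_hot_encode(sentence):
    """Encode each word as a one-hot vector and flatten the sentence."""
    vectors = []
    for word in sentence.split():
        vec = [0] * len(vocab)
        vec[vocab.index(word)] = 1
        vectors.extend(vec)
    return vectors

short = one_hot_encode("it is a good product")
long = one_hot_encode("it is a good product but it stopped working in few days")

# The flattened representations have different lengths, so a traditional
# network with a fixed-size input layer cannot consume both directly.
print(len(short), len(long))
```

Padding the shorter sentence to the longer one’s length would make both vectors fit one input layer, but at the cost of a larger network.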

To summarise, we need a network which can:

  1. retain previously learnt information and build on top of that.
  2. handle sequences of variable length without increasing the dimensionality

Recurrent Networks: an overview#

Some points to note before we begin:

  • A neural network is a computation unit
  • Looping the computation unit gives us a recurrent neural network

When you read about recurrent networks, this is a common diagram that you will stumble upon. Shown below is a simple recurrent cell.

Let us call this recurrent unit cell Q. Think of Q as a computation unit: it consumes some input x and generates output o. This cell Q is the building block of a recurrent neural network.

The loop shown above allows the network to carry previous information. Let us see how this network unrolls at each iteration (here an iteration is called a timestep). One thing to bear in mind before we proceed is that it is the same cell Q which we will be seeing at various timesteps.

t=0

At timestep 0 we do not have any previous information; hence the dotted line shown here indicates the initialization of the previous state.

t=1

At timestep 1 we do have previous information from timestep 0, and we can consume it to generate a new output. For now let us not get into the details of what information is being passed; we will cover that in the next section.

t=2

At timestep 2 we have previous information from timestep 1; we consume it and generate output in the same way.

The same process continues for subsequent timesteps, until we run out of input.
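The unrolling described above can be sketched in a few lines of Python. The body of the cell here is a hypothetical placeholder; the actual computation inside the cell is covered in the next sections:

```python
import numpy as np

def cell_Q(x_t, prev_state):
    """Placeholder computation unit: combines input with the previous state."""
    state = np.tanh(x_t + prev_state)
    output = state
    return output, state

# Toy inputs for timesteps t=0, 1, 2.
inputs = [np.array([0.5]), np.array([-0.1]), np.array([0.9])]

state = np.zeros(1)  # t=0: the previous state is just an initialization

for t, x_t in enumerate(inputs):
    o_t, state = cell_Q(x_t, state)  # the SAME cell Q at every timestep
    print(f"t={t}, output={o_t}")
```

Note that there is only one `cell_Q`; the loop simply reuses it, feeding each step’s state into the next.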

Basic RNN cell#

Now that we have some understanding of how information is passed from previous states, let us look in detail at what the cell state c_t actually does.

Simplified view of RNN cell

  • x_t = input at timestep t
  • c_t = current cell state
  • o_t = output at timestep t
  • c_{t-1} = previous cell state

The following steps occur when we pass input x_t to the RNN cell:

  1. Add the previous cell state c_{t-1} to the current input x_t
    • This addition operation ensures that sequential dependencies are taken care of
    • This is where we take the previous cell state into account
  2. Pass the above output through a non-linear activation f, here tanh
    • This non-linearly activated output is called the cell state c_t
  3. Pass c_t through g, here softmax, to get the current timestep output o_t
  4. The same cell state c_t will serve as the previous state information for timestep t+1

NOTE: At timestep t the RNN uses only the cell state of t-1; we do not have any information about states preceding t-1.

Weights#

This blog would be incomplete without a brief mention of the weights involved in RNNs and how back propagation happens over each timestep.

There are 3 major weight matrices involved in RNNs. These weights are essentially what an RNN learns.

  1. W_x - Input to Hidden Weights
    • Determines how the current input will influence the current cell state
  2. W_h - Hidden to Hidden Weights
    • Determines how the previous cell state will influence the current cell state
  3. W_o - Hidden to Output Weights
    • Determines the output based on the current cell state

:bulb: Again, it is worth noting that the same 3 weights are shared across all timesteps.

An RNN can be represented using these equations:

c_t = f(W_x \cdot x_t + W_h \cdot c_{t-1})

o_t = g(W_o \cdot c_t)

where f is the tanh and g is the softmax activation.
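These two equations can be sketched directly in NumPy. The dimensions here are made up for illustration (input size 4, cell state size 3, output size 2):

```python
import numpy as np

rng = np.random.default_rng(0)
W_x = rng.normal(size=(3, 4))  # input-to-hidden weights
W_h = rng.normal(size=(3, 3))  # hidden-to-hidden weights
W_o = rng.normal(size=(2, 3))  # hidden-to-output weights

def softmax(z):
    z = z - z.max()  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def rnn_step(x_t, c_prev):
    c_t = np.tanh(W_x @ x_t + W_h @ c_prev)  # c_t = f(W_x.x_t + W_h.c_{t-1})
    o_t = softmax(W_o @ c_t)                 # o_t = g(W_o.c_t)
    return o_t, c_t

# Unroll over a toy sequence; the same 3 weights are reused at each timestep.
c = np.zeros(3)
for x_t in [rng.normal(size=4) for _ in range(5)]:
    o, c = rnn_step(x_t, c)

print(o)  # a probability distribution over the 2 output classes
```

Notice that the weights are created once and shared by every call to `rnn_step`; only x_t and c change across timesteps.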

Back propagation through time#

We will look at back propagation for one of the weights, W_h, i.e. the weight matrix associated with the previous state. This is the most crucial set of parameters in the case of RNNs.

Like any other back propagation, we sum up individual losses. In the case of RNNs we sum up the losses associated with each timestep and learn how these losses vary with respect to the weight matrices.

\frac{\partial L}{\partial W_h} = \sum_{t=0}^{T} \frac{\partial L_t}{\partial W_h}

Let us look at the term t=3 of this sum.

This is how the RNN cell equations look at t=3:

o_3 = g(W_o \cdot c_3)

c_3 = f(W_x \cdot x_3 + W_h \cdot c_2)

We can write the derivative of loss L_3 w.r.t. W_h as

\frac{\partial L_3}{\partial W_h} = \frac{\partial L_3}{\partial o_3} \cdot \frac{\partial o_3}{\partial W_h}

Notice

  • o_3 depends on c_3
  • c_3 depends on c_2

Hence we need to apply the chain rule of differentiation:

\frac{\partial L_3}{\partial W_h} = \frac{\partial L_3}{\partial o_3} \cdot \frac{\partial o_3}{\partial c_3} \cdot \frac{\partial c_3}{\partial c_2} \cdot \frac{\partial c_2}{\partial W_h}

If we extend the same observation:

  • c_2 depends on c_1
  • We can apply the same steps across all timesteps

Hence we get the following equation:

\frac{\partial L_3}{\partial W_h} = \frac{\partial L_3}{\partial o_3} \cdot \frac{\partial o_3}{\partial c_3} \cdot \frac{\partial c_3}{\partial c_2} \cdot \frac{\partial c_2}{\partial c_1} \cdot \frac{\partial c_1}{\partial c_0} \cdot \frac{\partial c_0}{\partial W_h}

Similar calculations are done for \frac{\partial L_2}{\partial W_h}, \frac{\partial L_1}{\partial W_h} and \frac{\partial L_0}{\partial W_h} in order to get \sum_{t=0}^{T} \frac{\partial L_t}{\partial W_h}.

As we can see, BPTT (back propagation through time) is nothing but the sum of back propagations over each timestep.
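To see this concretely, here is a scalar sketch of BPTT for W_h, checked against a numerical gradient. The identity output activation and squared-error loss are simplifying assumptions for the sketch, not part of the derivation above:

```python
import numpy as np

# Scalar RNN: c_t = tanh(w_x * x_t + w_h * c_{t-1}).
# Simplifying assumptions: g is the identity (o_t = c_t) and the loss at
# each timestep is squared error L_t = (o_t - y_t)^2.
x = [0.5, -0.2, 0.8]   # toy inputs
y = [0.3, 0.1, -0.4]   # toy targets
w_x, w_h = 0.7, 0.9    # toy weights

def forward(w_h):
    """Run the RNN forward; return total loss and (c_{t-1}, c_t) per step."""
    c_prev, states, loss = 0.0, [], 0.0
    for x_t, y_t in zip(x, y):
        c_t = np.tanh(w_x * x_t + w_h * c_prev)
        states.append((c_prev, c_t))
        loss += (c_t - y_t) ** 2
        c_prev = c_t
    return loss, states

_, states = forward(w_h)

# BPTT: for each timestep t, chain dL_t/dc_t back through all earlier states.
grad = 0.0
for t in range(len(x)):
    c_t = states[t][1]
    dL_dc = 2 * (c_t - y[t])            # dL_t/dc_t
    for k in range(t, -1, -1):          # walk back from timestep t to 0
        c_prev_k, c_k = states[k]
        dtanh = 1 - c_k ** 2            # tanh'(a_k) expressed via c_k
        grad += dL_dc * dtanh * c_prev_k   # direct dc_k/dw_h contribution
        dL_dc *= dtanh * w_h               # propagate gradient to c_{k-1}

# Verify against a central-difference numerical gradient.
eps = 1e-6
num_grad = (forward(w_h + eps)[0] - forward(w_h - eps)[0]) / (2 * eps)
print(grad, num_grad)  # the two values should agree closely
```

The inner loop is exactly the chain of \partial c_k / \partial c_{k-1} factors from the equation above, applied once per loss term and summed.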

Limitations of RNNs#

  1. Vanishing or exploding gradients

We saw how the back propagation equation can be written for RNNs. Let us see how it looks for long sequences.

\frac{\partial L_{30}}{\partial W_h} = \frac{\partial L_{30}}{\partial o_{30}} \cdot \frac{\partial o_{30}}{\partial c_{30}} \cdot \frac{\partial c_{30}}{\partial c_{29}} \cdot \frac{\partial c_{29}}{\partial c_{28}} \cdots \frac{\partial c_1}{\partial c_0} \cdot \frac{\partial c_0}{\partial W_h}

Notice how we get a long chain of multiplications:

  • if the gradient factors are between 0 and 1, the product shrinks, leading to vanishing gradients
  • if the gradient factors are greater than 1, the product grows and gradients might explode
  2. RNNs are not able to learn long sequences

While RNNs are good at learning shorter sequences, their performance degrades with longer sequences. With longer sequences, gradients tend to get smaller, which hampers the network’s ability to learn and can result in very slow training.
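The effect of that long chain of multiplications can be illustrated with a toy numerical example (the factor values 0.5 and 1.5 are arbitrary stand-ins for per-step gradient factors):

```python
import numpy as np

# A chain of 30 multiplied gradient factors, as in a 30-step sequence.
small = np.prod(np.full(30, 0.5))   # every factor between 0 and 1
large = np.prod(np.full(30, 1.5))   # every factor greater than 1

print(small)  # ~9.3e-10 : the gradient effectively vanishes
print(large)  # ~1.9e+05 : the gradient explodes
```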

References#

  1. RNNs : Deep Learning Book by Aaron Courville, Ian Goodfellow, and Yoshua Bengio
  2. RNNs : Coursera Sequence Models
  3. BPTT : Backpropagation Through Time Overview - Sebastian Raschka
  4. BPTT : Backpropagation Through Time for Recurrent Neural Network - Mustafa Murat ARAT
Written by Sagar Sarkale