LSTM simplified

17 min read · 21 Sept 2023

Introduction#

In my previous blog about RNNs, "RNNs a walkthrough", we saw how recurrent neural networks work and their limitations, such as vanishing gradients, which cause them to fail at learning long sequences. LSTMs try to solve these limitations. In this blog we will walk through the underlying structure of an LSTM cell in detail.

This is my effort to simplify each of the blocks of an LSTM, piece by piece. In the end we join all these pieces to answer the question: how is an LSTM an add-on over RNNs?

The Gate#

Believe it or not, the image above is the key to understanding the fundamental blocks of an LSTM. The guard is protecting the gate above. For now, let us assume that you will not be allowed through the gate without paying a small cost. Now consider the same scenario with vectors: a vector cannot pass through a certain path without a small cost. In the diagram below you can see that each value is reduced by some factor when it passes through the gate.

How do we achieve this mathematically? Simple: using a sigmoid layer. The range of the sigmoid function is between 0 and 1. (For now, let us not worry about what the input to this sigmoid layer is.)

Notice how each element in the input vector $v$ is multiplied with the corresponding element of vector $t$, the output of the sigmoid layer, to give the resultant vector $v'$. This is called a pointwise operation; in this case we are performing pointwise multiplication. The operation depicted above allows partial passage of the input vector.

What if we do not want the input vector to pass through this special gate?

  • Vector $t$ should be $[0,0,0,0,0]$

What if we want the input vector to pass without any change?

  • Vector $t$ should be $[1,1,1,1,1]$

How do we concisely represent this “Gate”?#

  1. vectors $v$ and $t$ have $d$ dimensions
  2. ⓧ here is the gate (pointwise multiplication)
  3. $t$ has elements in the range 0 to 1 after sigmoid activation
  4. upon pointwise multiplication, a “fee” is collected to get the new resultant vector $v'$

This mechanism is critical to understand, as it forms the basis of each fundamental block of the LSTM that we are going to discuss.
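The gate can be sketched in a few lines of NumPy (a minimal illustration; the vectors and scores below are made up):

```python
import numpy as np

def sigmoid(z):
    # squashes every element into (0, 1)
    return 1.0 / (1.0 + np.exp(-z))

# input vector v trying to pass through the gate (made-up values)
v = np.array([2.0, -1.0, 0.5, 3.0, -2.0])

# t is the output of a sigmoid layer; the pre-activation scores are made up
scores = np.array([10.0, 0.0, -10.0, 2.0, -2.0])
t = sigmoid(scores)

# pointwise multiplication: each element of v pays a "fee" given by t
v_prime = t * v

# the two extremes: all-ones passes v unchanged, all-zeros blocks it entirely
assert np.allclose(np.ones(5) * v, v)
assert np.allclose(np.zeros(5) * v, np.zeros(5))
```

Elements of `t` near 1 let the corresponding element of `v` through almost untouched; elements near 0 block it almost entirely.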

Some notations before we begin#

Lego pieces#

An LSTM unit consists of 4 blocks: 3 blocks adjacent to each other and 1 block that connects them all. Each block has its own significance. In order to get a functional understanding, let's first look at each lego block individually; then we can club these pieces together to form a complete LSTM unit.

Separate Lego blocks of an LSTM
Joined Lego blocks of an LSTM

A Glimpse of an LSTM cell#

While an LSTM cell is similar to a vanilla RNN cell, one key distinction is that each state receives 2 inputs from the previous time step: the “cell state” $c_{t-1}$ and the “hidden state” $h_{t-1}$. All other concepts, like unrolling a recurrent neural network and backpropagation through time, remain the same.

Unrolling the LSTM cell for 3 timesteps: it recursively consumes the outputs of the previous state and the current input to generate a new state for the current time step.

NOTE: it is the same LSTM cell being viewed at 3 time steps.

A Math problem#

I will be using a simple analogy to explain LSTMs. Consider you are in school and are being taught simple math.

Situation 1 - In classroom while learning you solve the problem:

  • e.g. 10 apples cost 200 Rs; what is the cost of 2 apples?
  • Let us call this the APPLES question

Situation 2 - In a test you are given a problem:

  • e.g. 10 bananas cost 100 Rs; what is the cost of 3 bananas?
  • Let us call this the BANANAS question

Steps to solve:

  1. cost of one object is $c$
  2. cost of $n$ objects is $n \times c$
  3. solve for $c$ using the given information
  4. find the cost of $t$ objects using $t \times c$

💡 In order to solve a problem, all that matters is that you are able to recall the concept and use previously learnt steps to solve an equivalent problem in the future.

Though this is a trivial problem, it is important to list these steps down, as we will map them onto the internal blocks of LSTMs.

Opening the Gates#

Gate #1 - Forget Gate#

What would your initial steps be while solving the BANANAS question in the test?

  1. You retain
    1. the information you got from the new question
    2. for instance, the data mentioned in the bananas question
  2. You recall that
    1. you were taught to find the cost of one apple
    2. you need to multiply the cost of one apple with the N apples in the question
  3. You forget
    1. whether the problem is about apples or bananas (it is actually about finding the cost of an object)
    2. the answer to the question you did in class (as it is not relevant for the bananas problem)

Now let us impose this analogy on LSTMs.

| Symbol | Meaning |
| --- | --- |
| $F_t$ | output of the Forget gate |
| $W_f$ | weights associated with the Forget gate |
| $b_f$ | bias associated with the Forget gate |
| $f_t$ | sigmoid-activated vector associated with the Forget gate |
| $[h_{t-1},\ x_t]$ | concatenation of vectors $h_{t-1}$ and $x_t$ |
| $h_{t-1}$ | previous hidden state |
| $c_{t-1}$ | previous cell state |
| $x_t$ | current input |
  1. retain
    • input $x_t$ and previous time step output $h_{t-1}$
    • like the data mentioned in the bananas question
  2. recall
    • using the cell state: $c_{t-1}$ represents previously learnt information
    • like the steps to solve the problem
  3. forget
    • partially, some specifics which are not necessary, using the forget gate
    • like whether it is the apple or the banana problem

It partially forgets information, conditioned on the concatenated vector $[h_{t-1}, x_t]$. Hence the name “Forget Gate”.

Now try to recall what we learnt in the earlier section about gates: it costs a certain fee to pass through one. The same concept is used here in order to “forget” information partially.

💡 But who decides what degree of information should be forgotten?

This is where weights and learning come into play. Let's take a closer look at the forget gate. Mathematically, the forget gate can be represented as:

  • filter for the “Forget Gate” (sigmoid activation)

$$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$$

  • output of the “Forget Gate” (pointwise multiplication)

$$F_t = c_{t-1} \otimes f_t$$

Notice how the decision on the degree of information to be forgotten is a function of the 2 vectors $h_{t-1}$ and $x_t$. We forget previously learnt information partially, on the basis of the previous time step's output as well as the current time step's input.
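Here is a minimal NumPy sketch of the forget gate. The dimensions are made up and the weights are random stand-ins; a real network would learn $W_f$ and $b_f$:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
d, x_dim = 4, 3                            # made-up hidden and input sizes

h_prev = rng.standard_normal(d)            # h_{t-1}: previous hidden state
x_t    = rng.standard_normal(x_dim)        # x_t: current input
c_prev = rng.standard_normal(d)            # c_{t-1}: previous cell state

concat = np.concatenate([h_prev, x_t])     # [h_{t-1}, x_t]
W_f = rng.standard_normal((d, d + x_dim))  # forget-gate weights (random stand-in)
b_f = np.zeros(d)                          # forget-gate bias

f_t = sigmoid(W_f @ concat + b_f)          # filter: each element in (0, 1)
F_t = c_prev * f_t                         # partially forgotten cell state
```

Because every element of `f_t` lies strictly between 0 and 1, `F_t` can only shrink the magnitude of each element of `c_prev`, never grow it.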

Now that we have passed through Gate #1 let us close it behind us and move to the next one.

Gate #2 - Input Gate#

Again, let us ask the same question. Now that you remember how to solve the problem, and you understand that it involves calculating the cost of a single unit, what would your next step be?

  1. You recall that
    1. you need to set up a new equation based on the new data
    2. this equation can be built using previous knowledge
  2. You set up
    1. a new equation to plug values into
  3. You forget
    1. the older equation (partially)
    2. the values inserted in the older equation to solve the apples question (completely)

Now let us try to see what LSTM does in this second stage.

| Symbol | Meaning |
| --- | --- |
| $F_t$ | output of the Forget gate |
| $I_t$ | output of the Input gate |
| $W_i$ | weights associated with the Input gate filter |
| $W_c$ | weights associated with the intermediate cell state |
| $i_t$ | sigmoid-activated vector associated with the Input gate |
| $\tilde{c_t}$ | tanh-activated vector (intermediate cell state) associated with the Input gate |
  1. recall and retain
    1. input $x_t$ and previous time step output $h_{t-1}$
    2. like recalling the previous equation used to solve the apples problem and retaining the bananas problem
  2. set up
    1. here the $\tanh$-activated vector $\tilde{c_t}$ is a good representation of this step
    2. like coming up with the equation for solving the bananas problem
  3. filter
    1. here the pointwise multiplication of $i_t$ and $\tilde{c_t}$ represents filtering
    2. like altering the values inserted in the older equation used to solve the apples question

Very similar to the “Forget Gate”, here the LSTM tries to determine what degree of new information should be passed on to the next gate. This stage determines how to update the old cell state of the LSTM and by what amount. The weights which control the Input gate are $W_i$, which helps with filtering the input, and $W_c$, which is associated with generating the intermediate cell state $\tilde{c_t}$.

Mathematically it can be represented as

  • create a filter for the “Input gate”

$$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$$

  • tanh activation for the intermediate cell state of the current time step

$$\tilde{c_t} = \tanh(W_c \cdot [h_{t-1}, x_t] + b_c)$$

  • output of the “Input gate”

$$I_t = \tilde{c_t} \otimes i_t$$

Using the output of the “Forget gate” $F_t$ and the “Input gate” $I_t$, a new vector is formed, called the cell state. It essentially holds new information along with some learnings from the past. Very much like us humans: we add new learnings on top of our knowledge bank while retaining learnings from the past.

This cell state can be mathematically represented as

$$c_t = F_t \oplus I_t$$

Notice how we are adding to previous knowledge $F_t$ using a pointwise operation here. This cell state is further used by the “Output Gate” as well as by the next timestep.
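A minimal NumPy sketch of the input gate and the cell-state update, with made-up dimensions and random stand-in weights:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(1)
d, x_dim = 4, 3                            # made-up sizes
h_prev = rng.standard_normal(d)            # h_{t-1}
x_t    = rng.standard_normal(x_dim)        # x_t
c_prev = rng.standard_normal(d)            # c_{t-1}
concat = np.concatenate([h_prev, x_t])     # [h_{t-1}, x_t]

W_f = rng.standard_normal((d, d + x_dim))
W_i = rng.standard_normal((d, d + x_dim))
W_c = rng.standard_normal((d, d + x_dim))
b_f, b_i, b_c = np.zeros(d), np.zeros(d), np.zeros(d)

f_t = sigmoid(W_f @ concat + b_f)          # forget filter
F_t = c_prev * f_t                         # output of forget gate

i_t     = sigmoid(W_i @ concat + b_i)      # input filter
c_tilde = np.tanh(W_c @ concat + b_c)      # intermediate cell state
I_t     = c_tilde * i_t                    # output of input gate

c_t = F_t + I_t                            # new cell state: old + new knowledge
```

The final line is the pointwise addition $c_t = F_t \oplus I_t$: filtered old knowledge plus filtered new knowledge.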

Gate #3 - Output Gate#

Now we have the equation set up for us; all we need to do is plug in the values and get the answer to our bananas problem.

  1. You retain
    1. the values of the bananas problem
  2. You solve
    1. for the new answer according to the bananas problem
    2. (using a previously learnt concept like multiplication)
  3. You filter
    1. imagine you are required to fill just the final answer into the test
    2. the steps are not relevant here, so you filter them out and report just the final answer

Let us look at how LSTMs generate the final output.

| Symbol | Meaning |
| --- | --- |
| $W_o$ | weights associated with the Output gate |
| $o_t$ | sigmoid-activated vector associated with the Output gate |
| $c_t$ | cell state associated with the current time step |
| $c_t'$ | tanh-activated vector associated with the Output gate |
| $h_t$ | hidden state output associated with the current time step |
  1. recall and retain
    1. input $x_t$ and previous time step output $h_{t-1}$ → this step remains the same
    2. like recalling the previous method to get the final answer and retaining the bananas problem
  2. solve
    1. here the $\tanh$ activation of the cell state is a good representation of the process of getting the answer
    2. like solving for the cost of $n$ objects
  3. filter
    1. creating a filter similar to the earlier gates, here termed $o_t$
    2. filtering $c_t'$ using $o_t$ to get $h_t$, which is the final output of the current state
    3. like reporting just the final answer

The goal of this gate is to generate the hidden state output $h_t$ associated with the current time step. Filtering using sigmoid activation remains the same as in the earlier gates; only the weights and biases associated with this gate change.

Mathematically this stage can be represented as:

  • create a filter for the “Output gate”

$$o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$$

  • tanh activation associated with the “Output gate”

$$c_t' = \tanh(c_t)$$

  • output of the “Output gate”

$$h_t = o_t \otimes c_t'$$

Notice how

  1. the final output $h_t$ is a function of the transformed current cell state $c_t'$
  2. in the earlier gates we saw how the cell state $c_t$ is a function of the previous cell state $c_{t-1}$ and the intermediate cell state $\tilde{c_t}$
  3. this happens in a recursive fashion, just like vanilla RNNs
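Joining all the lego pieces, one forward step of the cell can be sketched end to end. This is a minimal NumPy sketch; `lstm_step`, the sizes, and the random weights are all made up for illustration:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, params):
    """One LSTM time step, following the gate equations in this post."""
    W_f, b_f, W_i, b_i, W_c, b_c, W_o, b_o = params
    z = np.concatenate([h_prev, x_t])       # [h_{t-1}, x_t]

    f_t     = sigmoid(W_f @ z + b_f)        # forget filter
    i_t     = sigmoid(W_i @ z + b_i)        # input filter
    c_tilde = np.tanh(W_c @ z + b_c)        # intermediate cell state
    c_t = f_t * c_prev + i_t * c_tilde      # cell state (F_t plus I_t)

    o_t = sigmoid(W_o @ z + b_o)            # output filter
    h_t = o_t * np.tanh(c_t)                # hidden state (o_t times c_t')
    return h_t, c_t

rng = np.random.default_rng(42)
d, x_dim = 4, 3                             # made-up sizes
params = []
for _ in range(4):                          # W and b for each of the 4 layers
    params += [0.1 * rng.standard_normal((d, d + x_dim)), np.zeros(d)]

# unroll the SAME cell for 3 time steps, as in the diagram
h, c = np.zeros(d), np.zeros(d)
for step in range(3):
    h, c = lstm_step(rng.standard_normal(x_dim), h, c, params)
```

Since $h_t = o_t \otimes \tanh(c_t)$ with $o_t \in (0,1)$, every element of the hidden state stays inside $(-1, 1)$, while the cell state $c_t$ is free to grow.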

All these transformations are good, but what exactly makes the LSTM so special? For this we will need to ask ourselves how LSTMs solved the issues which vanilla RNNs faced.

The missing link#

In this section we will try to answer how the LSTM solves for “long term dependencies”. The easiest way to explain the piece that binds it all is to show how the gates can be utilised to retain long term information.

For now we are concerned with two gates: the forget gate and the input gate. We clear the rest of the paths in the LSTM diagram for ease of explanation.

Equations of LSTM at various timesteps

timestep = t

  • Output of the forget gate

$$F_t = f_t \otimes c_{t-1}$$

  • Output of the input gate

$$I_t = i_t \otimes \tilde{c_t}$$

  • New cell state

$$c_t = F_t \oplus I_t$$

These equations can be combined to get

$$c_t = [f_t \otimes c_{t-1}] \oplus [i_t \otimes \tilde{c_t}]$$

Let us see the same set of equations at various timesteps

timestep = t+1

$$c_{t+1} = [f_{t+1} \otimes c_t] \oplus [i_{t+1} \otimes \tilde{c_{t+1}}]$$

timestep = t

$$c_t = [f_t \otimes c_{t-1}] \oplus [i_t \otimes \tilde{c_t}]$$

timestep = t-1

$$c_{t-1} = [f_{t-1} \otimes c_{t-2}] \oplus [i_{t-1} \otimes \tilde{c_{t-1}}]$$

Notice how the cell state $c_t$ at timestep $t$ is always the sum of two gated products, gated by the forget gate $f_t$ and the input gate $i_t$.

Recall here

  • how do we allow an entire vector through the gate unchanged?
    • a vector of 1s: $[1,1,1,1,1]$
  • how do we stop an entire vector from passing through the gate?
    • a vector of 0s: $[0,0,0,0,0]$

A similar concept is applied here.

How can we stop vectors at the forget gate?

  • when $f_t = [0,0,0,0,0]$
  • what happens to the output of the forget gate?

$$F_t = f_t \otimes c_{t-1} = [0,0,0,0,0] \otimes c_{t-1}$$
$$\therefore F_t = [0,0,0,0,0]$$

How can we stop vectors at the input gate?

  • when $i_t = [0,0,0,0,0]$
  • what happens to the output of the input gate?

$$I_t = i_t \otimes \tilde{c_t} = [0,0,0,0,0] \otimes \tilde{c_t}$$
$$\therefore I_t = [0,0,0,0,0]$$

Consider this simplified diagram where 3 time steps of the LSTM are shown. Let us view how information flows with respect to time step $t+1$. The stickman at time step $t+1$ needs some information from a previous time step along the blue path shown below; let us see if the LSTM is able to provide it.

For ease of explanation,

  • $f_t = c$ means $f_t = [c,c,c,c,c]$
  • and similarly for $i_t$

where $c = 0$ or $1$.

Case #1#

  • $f_t = 1,\ i_t = 0$

💬 i ONLY need $c_{t-1}$, NOT $c_t$

This case is what enables LSTM to carry over long term memory.

Some math in action:

$$c_t = F_t \oplus I_t$$

$$c_t = [1 \otimes c_{t-1}] \oplus 0 \quad \text{where } f_t = 1,\ i_t = 0$$

$$\therefore c_t = c_{t-1}$$

As we can see, this gate setting lets the LSTM act as a skip connection for the cell state, skipping one or more states in between to reach the required time step.
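A tiny numeric check of Case #1, with the gates pinned by hand instead of learnt (all values made up):

```python
import numpy as np

c_prev  = np.array([0.3, -1.2, 0.7, 2.0, -0.5])  # c_{t-1}
c_tilde = np.array([9.9,  9.9, 9.9, 9.9,  9.9])  # candidate (irrelevant here)

f_t = np.ones(5)     # forget gate fully open
i_t = np.zeros(5)    # input gate fully closed

c_t = f_t * c_prev + i_t * c_tilde

# the cell state is carried over unchanged: a skip connection across time
assert np.allclose(c_t, c_prev)
```

No matter how large the candidate `c_tilde` is, a closed input gate keeps it out, and the old cell state rides through untouched.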

Imagine we need the cell state from time step $(t-w)$; all the intermediate gates between time steps $(t-w)$ and $(t+1)$ will be set to $f_q = 1,\ i_q = 0$, where $q$ is the index of an intermediate time step.

The following diagram illustrates the same.

Case #2#

  • $f_t = 0,\ i_t = 1$

💬 i do NOT need cell states before time step $t$

Again some math:

$$c_t = F_t \oplus I_t$$

$$c_t = [f_t \otimes c_{t-1}] \oplus [i_t \otimes \tilde{c_t}] \quad \text{where } f_t = 0,\ i_t = 1$$

$$c_t = [0 \otimes c_{t-1}] \oplus [1 \otimes \tilde{c_t}]$$

$$c_t = [1 \otimes \tilde{c_t}] = \tilde{c_t}$$

Here we can see that no vector passes through the forget gate, and we use ONLY the output of the input gate to generate the cell state at time $t$, to be consumed at time step $t+1$.

Case #3#

  • $f_t = 0,\ i_t = 0$

💬 i need a break

This is where the LSTM

  • erases all previous information
  • and does NOT learn anything new

Mathematically

$$c_t = F_t \oplus I_t$$

$$c_t = [f_t \otimes c_{t-1}] \oplus [i_t \otimes \tilde{c_t}] \quad \text{where } f_t = 0,\ i_t = 0$$

$$c_t = [0 \otimes c_{t-1}] \oplus [0 \otimes \tilde{c_t}]$$

$$c_t = 0$$

Here neither $c_{t-1}$ nor $\tilde{c_t}$ makes it past its gate: all previous information is erased, nothing new is written, and the cell state is reset to zero.

Case #4#

  • $f_t = 1,\ i_t = 1$

💬 i need the gist of all the cell states from previous time steps

(can't catch a break now, can i?)

Mathematically

$$c_t = F_t \oplus I_t$$

$$c_t = [f_t \otimes c_{t-1}] \oplus [i_t \otimes \tilde{c_t}] \quad \text{where } f_t = 1,\ i_t = 1$$

$$c_t = [1 \otimes c_{t-1}] \oplus [1 \otimes \tilde{c_t}]$$

$$c_t = c_{t-1} \oplus \tilde{c_t}$$

Here we allow both vectors $c_{t-1}$ and $\tilde{c_t}$ to pass through the forget gate and the input gate respectively. We get information from both the current state and the states before it, as the sum of the two vectors.

End note#

This brings us to the end of this long blog. We covered how LSTMs use gated flow of vectors to retain both long and short term information, but there is still some ground left to cover, like how LSTMs actually solve the vanishing gradient problem. That part is covered in detail in the references mentioned below. I hope reading this gave you an in-depth as well as intuitive understanding of LSTMs.

One final “diagrammatic” takeaway before ending this blog

This is how LSTMs keep track of both long and short term memory.

References#

  1. The blog Understanding LSTM Networks by Christopher Olah is the gold standard and helped me understand the internal workings of LSTMs. Huge shoutout; my blog is a mere interpretation of the concepts explained there.
  2. CS Toronto slides
  3. Detailed walkthrough LSTMs (Video series)
    1. Back propagation and LSTM equations
    2. How LSTMs solve vanishing gradients
  4. How LSTM networks solve the problem of vanishing gradients
  5. Back propagation using dependency graph
  6. How the LSTM improves the RNN
Written by Sagar Sarkale