
I'm having trouble understanding the documentation for PyTorch's LSTM module (and also RNN and GRU, which are similar). Regarding the outputs, it says:

Outputs: output, (h_n, c_n)

  • output (seq_len, batch, hidden_size * num_directions): tensor containing the output features (h_t) from the last layer of the RNN, for each t. If a torch.nn.utils.rnn.PackedSequence has been given as the input, the output will also be a packed sequence.
  • h_n (num_layers * num_directions, batch, hidden_size): tensor containing the hidden state for t=seq_len
  • c_n (num_layers * num_directions, batch, hidden_size): tensor containing the cell state for t=seq_len

It seems that the variables output and h_n both give the values of the hidden state. Does h_n just redundantly provide the last time step that's already included in output, or is there something more to it than that?

kmario23
N. Virgo

5 Answers


I made a diagram. The names follow the PyTorch docs, although I renamed num_layers to w.

output comprises all the hidden states in the last layer ("last" depth-wise, not time-wise). (h_n, c_n) comprises the hidden states after the last timestep, t = n, so you could potentially feed them into another LSTM.

LSTM diagram

The batch dimension is not included.
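
As a rough sketch of the "feed them into another LSTM" point (the sizes below are arbitrary assumptions, not taken from the diagram): splitting a sequence into two chunks and passing the (h_n, c_n) from the first call into the second should reproduce what a single pass over the whole sequence gives.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary illustrative sizes (assumptions, not from the answer above).
seq_len, batch, input_size, hidden_size, num_layers = 6, 2, 4, 3, 2
lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers)
x = torch.randn(seq_len, batch, input_size)

# One pass over the whole sequence.
output, (h_n, c_n) = lstm(x)
print(output.shape)  # torch.Size([6, 2, 3]) -> last layer, every time step
print(h_n.shape)     # torch.Size([2, 2, 3]) -> every layer, last time step
print(c_n.shape)     # torch.Size([2, 2, 3])

# The same sequence in two chunks, carrying (h_n, c_n) across the boundary.
out_a, state = lstm(x[:3])
out_b, (h_b, c_b) = lstm(x[3:], state)

chunks_match = torch.allclose(torch.cat([out_a, out_b]), output, atol=1e-6)
states_match = (torch.allclose(h_b, h_n, atol=1e-6)
                and torch.allclose(c_b, c_n, atol=1e-6))
print(chunks_match, states_match)
```

This is why both tensors are returned: output is what you usually pass to the next module, while (h_n, c_n) is what you need to resume the recurrence.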

nnnmmm
  • Great, thanks, that makes a lot of sense and is really helpful. So that means, for example, that there is no way to get the hidden values for all layers at a time step other than the last one? – N. Virgo Jan 18 '18 at 00:14
  • Right, unless you have individual LSTMs with `num_layers = 1` that take the previous net's output as input. – nnnmmm Jan 18 '18 at 07:43
  • @nnnmmm So, each (blue) box is an LSTM/RNN/GRU unit, right? And `h_i` and `c_i` are the *hidden and cell states* respectively, and `w` is the depth of our network, right? – kmario23 Jan 19 '18 at 00:19
  • @kmario23: yes, each blue box is an LSTM unit. As I understand, vanilla RNN and GRU don't have cell states, just hidden states, so they would look a little different. You're right about `h_i`, `c_i` and `w`. – nnnmmm Jan 19 '18 at 07:55
  • There are two units labeled `h_1^(1)`; the right one should probably be `h_2^(1)`. – Andre Holzner Jan 20 '18 at 09:05
  • Thanks, I fixed it. – nnnmmm Jan 24 '18 at 11:29
  • @nnnmmm: What did you use to make the diagram, may I ask? – Prakhar Agrawal Apr 16 '18 at 17:21
  • It was draw.io, but I'm not sure if it's the best tool for the job. – nnnmmm Apr 16 '18 at 17:43
  • @nnnmmm, is the number of h (depth-wise) the number of stacked LSTM cells? – jted95 Oct 05 '18 at 00:05
  • It would be much appreciated if you add a similar graph for bidirectional LSTM :). – ndrwnaguib Apr 14 '19 at 23:42
  • This is way clearer than PyTorch's official doc. They should include this pic. So many thanks for this. Amazing. Now I understand exactly what the output means. – SRC May 07 '19 at 20:04
  • So for a depth 1 LSTM, the following statement from the question will be correct: *Does h_n just redundantly provide the last time step that's already included in output* – Pulkit Bansal Nov 06 '20 at 11:49
  • The documentation mentions o_t as output gate, and then says here that this output refers to h_t for each time step. This is causing a lot of confusion. – Pulkit Bansal Nov 06 '20 at 11:52
  • I know I'm not supposed to say thank you, but I don't care. This diagram is great. I've been blindly coding LSTMs for a while and this really clarified what's going on under the hood in a way the docs never did. – rocksNwaves Dec 08 '20 at 13:58
  • How does $h_1^1$ combine with $h_2^0$? – Gulzar Mar 10 '21 at 13:19
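
The workaround mentioned in the comments (separate LSTMs with `num_layers = 1`, each taking the previous one's output) could look roughly like the sketch below. The sizes and the names `layers`, `per_layer_outputs` and `all_hidden` are made up for illustration. The weights are independent of any `nn.LSTM(num_layers=3)` you might have, so this shows the wiring, not a drop-in replacement.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary illustrative sizes (assumptions).
seq_len, batch, input_size, hidden_size, depth = 5, 1, 3, 4, 3

# One single-layer LSTM per depth level; the first consumes the input,
# the others consume the previous level's output.
layers = nn.ModuleList([
    nn.LSTM(input_size if k == 0 else hidden_size, hidden_size, num_layers=1)
    for k in range(depth)
])

x = torch.randn(seq_len, batch, input_size)
per_layer_outputs = []
layer_input = x
for layer in layers:
    layer_output, _ = layer(layer_input)  # (seq_len, batch, hidden_size)
    per_layer_outputs.append(layer_output)
    layer_input = layer_output

# Hidden states for every layer at every time step.
all_hidden = torch.stack(per_layer_outputs)
print(all_hidden.shape)  # torch.Size([3, 5, 1, 4]) -> (depth, seq_len, batch, hidden_size)
```

Here `all_hidden[k, t]` is the hidden state of layer k at time step t, which a single multi-layer `nn.LSTM` call does not expose for t before the last step.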

I just verified some of this with code, and it is indeed correct that for a depth-1 LSTM, h_n is the same as the last value of output. (For an LSTM with depth > 1 this holds only for the last layer's slice of h_n, not for h_n as a whole, as explained above by @nnnmmm.)

So the "output" returned by the LSTM is not o_t (the output gate) as defined in the documentation; it is h_t.

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.LSTM( input_size = 1, hidden_size = 50, num_layers  = 1 )
x = torch.rand( 50, 1, 1)
output, (hn, cn) = model(x)

Now one can check that output[-1] and hn[0] hold the same values; both print as follows:

tensor([[ 0.1140, -0.0600, -0.0540,  0.1492, -0.0339, -0.0150, -0.0486,  0.0188,
          0.0504,  0.0595, -0.0176, -0.0035,  0.0384, -0.0274,  0.1076,  0.0843,
         -0.0443,  0.0218, -0.0093,  0.0002,  0.1335,  0.0926,  0.0101, -0.1300,
         -0.1141,  0.0072, -0.0142,  0.0018,  0.0071,  0.0247,  0.0262,  0.0109,
          0.0374,  0.0366,  0.0017,  0.0466,  0.0063,  0.0295,  0.0536,  0.0339,
          0.0528, -0.0305,  0.0243, -0.0324,  0.0045, -0.1108, -0.0041, -0.1043,
         -0.0141, -0.1222]], grad_fn=<SelectBackward>)
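
Rather than comparing printed tensors by eye, the check can be done programmatically. This is a small sketch along the lines of the code above; the two-layer model `deep` is an extra I added to show where the equality stops holding.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.rand(50, 1, 1)  # (seq_len, batch, input_size)

# Depth 1: the last time step of output is exactly h_n.
single = nn.LSTM(input_size=1, hidden_size=50, num_layers=1)
output, (hn, cn) = single(x)
single_matches = torch.allclose(output[-1], hn[0])
print(single_matches)

# Depth 2 (added for contrast): only the last layer's slice of h_n matches.
deep = nn.LSTM(input_size=1, hidden_size=50, num_layers=2)
output2, (hn2, cn2) = deep(x)
last_layer_matches = torch.allclose(output2[-1], hn2[-1])
first_layer_matches = torch.allclose(output2[-1], hn2[0])
print(last_layer_matches)   # last layer: expected to match
print(first_layer_matches)  # first layer: expected to differ with random weights
```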
Pulkit Bansal

In PyTorch, output holds the hidden state h_t of the last layer of the LSTM stack for every time step, while the returned hidden state and cell state (h_n, c_n) hold the hidden and cell states of every layer in the stack at the last time step.

import torch
import torch.nn as nn

torch.manual_seed(1)
# 5 time steps of one sequence; each step is a (1, 3) tensor: batch size 1, 3 input features.
inputs = [torch.randn(1, 3) for _ in range(5)]
# Initial (h_0, c_0), each of shape (num_layers * num_directions, batch, hidden_size) = (1 * 1, 1, 3).
# The batch size must match the input's, and the last dimension must equal hidden_size.
hidden = (torch.randn(1, 1, 3),
          torch.randn(1, 1, 3))

lstm = nn.LSTM(3, 3)  # input_size=3, hidden_size=3, one layer, unidirectional
for i in inputs:
    # Feed one time step per call, shaped (seq_len=1, batch=1, input_size=3),
    # and carry (h, c) forward between calls.
    out, hidden = lstm(i.view(1, 1, -1), hidden)
    print('out:', out)
    print('hidden:', hidden)

Output

out: tensor([[[-0.1124, -0.0653,  0.2808]]], grad_fn=<StackBackward>)
hidden: (tensor([[[-0.1124, -0.0653,  0.2808]]], grad_fn=<StackBackward>), tensor([[[-0.2883, -0.2846,  2.0720]]], grad_fn=<StackBackward>))
out: tensor([[[ 0.1675, -0.0376,  0.4402]]], grad_fn=<StackBackward>)
hidden: (tensor([[[ 0.1675, -0.0376,  0.4402]]], grad_fn=<StackBackward>), tensor([[[ 0.4394, -0.1226,  1.5611]]], grad_fn=<StackBackward>))
out: tensor([[[0.3699, 0.0150, 0.1429]]], grad_fn=<StackBackward>)
hidden: (tensor([[[0.3699, 0.0150, 0.1429]]], grad_fn=<StackBackward>), tensor([[[0.8432, 0.0618, 0.9413]]], grad_fn=<StackBackward>))
out: tensor([[[0.1795, 0.0296, 0.2957]]], grad_fn=<StackBackward>)
hidden: (tensor([[[0.1795, 0.0296, 0.2957]]], grad_fn=<StackBackward>), tensor([[[0.4541, 0.1121, 0.9320]]], grad_fn=<StackBackward>))
out: tensor([[[0.1365, 0.0596, 0.3931]]], grad_fn=<StackBackward>)
hidden: (tensor([[[0.1365, 0.0596, 0.3931]]], grad_fn=<StackBackward>), tensor([[[0.3430, 0.1948, 1.0255]]], grad_fn=<StackBackward>))

Multi-Layered LSTM

import torch
import torch.nn as nn

torch.manual_seed(1)
num_layers = 2
inputs = [torch.randn(1, 3) for _ in range(5)]
# (h_0, c_0): (num_layers * num_directions, batch, hidden_size) = (2 * 1, 1, 3)
hidden = (torch.randn(num_layers, 1, 3),
          torch.randn(num_layers, 1, 3))
lstm = nn.LSTM(input_size=3, hidden_size=3, num_layers=num_layers)
for i in inputs:
    # Step through the sequence one element at a time.
    # After each step, hidden holds (h, c) for both layers; out equals the last layer's h.
    out, hidden = lstm(i.view(1, 1, -1), hidden)
    print('out:', out)
    print('hidden:', hidden)

Output

out: tensor([[[-0.0819,  0.1214, -0.2586]]], grad_fn=<StackBackward>)
hidden: (tensor([[[-0.2625,  0.4415, -0.4917]],

        [[-0.0819,  0.1214, -0.2586]]], grad_fn=<StackBackward>), tensor([[[-2.5740,  0.7832, -0.9211]],

        [[-0.2803,  0.5175, -0.5330]]], grad_fn=<StackBackward>))
out: tensor([[[-0.1298,  0.2797, -0.0882]]], grad_fn=<StackBackward>)
hidden: (tensor([[[-0.3818,  0.3306, -0.3020]],

        [[-0.1298,  0.2797, -0.0882]]], grad_fn=<StackBackward>), tensor([[[-2.3980,  0.6347, -0.6592]],

        [[-0.3643,  0.9301, -0.1326]]], grad_fn=<StackBackward>))
out: tensor([[[-0.1630,  0.3187,  0.0728]]], grad_fn=<StackBackward>)
hidden: (tensor([[[-0.5612,  0.3134, -0.0782]],

        [[-0.1630,  0.3187,  0.0728]]], grad_fn=<StackBackward>), tensor([[[-1.7555,  0.6882, -0.3575]],

        [[-0.4571,  1.2094,  0.1061]]], grad_fn=<StackBackward>))
out: tensor([[[-0.1723,  0.3274,  0.1546]]], grad_fn=<StackBackward>)
hidden: (tensor([[[-0.5112,  0.1597, -0.0901]],

        [[-0.1723,  0.3274,  0.1546]]], grad_fn=<StackBackward>), tensor([[[-1.4417,  0.5892, -0.2489]],

        [[-0.4940,  1.3620,  0.2255]]], grad_fn=<StackBackward>))
out: tensor([[[-0.1847,  0.2968,  0.1333]]], grad_fn=<StackBackward>)
hidden: (tensor([[[-0.3256,  0.3217, -0.1899]],

        [[-0.1847,  0.2968,  0.1333]]], grad_fn=<StackBackward>), tensor([[[-1.7925,  0.6096, -0.4432]],

        [[-0.5147,  1.4031,  0.2014]]], grad_fn=<StackBackward>))

Bi-Directional Multi-Layered LSTM

import torch
import torch.nn as nn

torch.manual_seed(1)
num_layers = 2
is_bidirectional = True
num_directions = 2 if is_bidirectional else 1
inputs = [torch.randn(1, 3) for _ in range(5)]
# (h_0, c_0): (num_layers * num_directions, batch, hidden_size) = (2 * 2, 1, 3)
hidden = (torch.randn(num_layers * num_directions, 1, 3),
          torch.randn(num_layers * num_directions, 1, 3))
lstm = nn.LSTM(input_size=3, hidden_size=3, num_layers=num_layers, bidirectional=is_bidirectional)

for i in inputs:
    # Step through the sequence one element at a time.
    # After each step, hidden holds (h, c) for every layer and direction.
    # Note: with seq_len = 1 per call, the backward direction never sees later elements of the sequence.
    out, hidden = lstm(i.view(1, 1, -1), hidden)
    print('out:', out)
    print('hidden:', hidden)
    # out shape        -> (seq_len, batch, num_directions * hidden_size) -> (1, 1, 2 * 3) per call
    # hidden shape     -> (num_layers * num_directions, batch, hidden_size) -> (2 * 2, 1, 3)
    # cell state shape -> (num_layers * num_directions, batch, hidden_size) -> (2 * 2, 1, 3)

Output

out: tensor([[[-0.4620,  0.1115, -0.1087,  0.1646,  0.0173, -0.2196]]],
       grad_fn=<CatBackward>)
hidden: (tensor([[[ 0.5187,  0.2656, -0.2543]],

        [[ 0.4175,  0.0539,  0.0633]],

        [[-0.4620,  0.1115, -0.1087]],

        [[ 0.1646,  0.0173, -0.2196]]], grad_fn=<StackBackward>), tensor([[[ 1.1546,  0.4012, -0.4119]],

        [[ 0.7999,  0.2632,  0.2587]],

        [[-1.4196,  0.2075, -0.3148]],

        [[ 0.6605,  0.0243, -0.5783]]], grad_fn=<StackBackward>))
out: tensor([[[-0.1860,  0.1359, -0.2719,  0.0815,  0.0061, -0.0980]]],
       grad_fn=<CatBackward>)
hidden: (tensor([[[ 0.2945,  0.0842, -0.1580]],

        [[ 0.2766, -0.1873,  0.2416]],

        [[-0.1860,  0.1359, -0.2719]],

        [[ 0.0815,  0.0061, -0.0980]]], grad_fn=<StackBackward>), tensor([[[ 0.5453,  0.1281, -0.2497]],

        [[ 0.9706, -0.3592,  0.4834]],

        [[-0.3706,  0.2681, -0.6189]],

        [[ 0.2029,  0.0121, -0.3028]]], grad_fn=<StackBackward>))
out: tensor([[[ 0.1095,  0.1520, -0.3238,  0.0283,  0.0387, -0.0820]]],
       grad_fn=<CatBackward>)
hidden: (tensor([[[ 0.1427,  0.0859, -0.2926]],

        [[ 0.1536, -0.2343,  0.0727]],

        [[ 0.1095,  0.1520, -0.3238]],

        [[ 0.0283,  0.0387, -0.0820]]], grad_fn=<StackBackward>), tensor([[[ 0.2386,  0.1646, -0.4102]],

        [[ 0.2636, -0.4828,  0.1889]],

        [[ 0.1967,  0.2848, -0.7155]],

        [[ 0.0735,  0.0702, -0.2859]]], grad_fn=<StackBackward>))
out: tensor([[[ 0.2346,  0.1576, -0.4006, -0.0053,  0.0256, -0.0653]]],
       grad_fn=<CatBackward>)
hidden: (tensor([[[ 0.1706,  0.0147, -0.0341]],

        [[ 0.1835, -0.3951,  0.2506]],

        [[ 0.2346,  0.1576, -0.4006]],

        [[-0.0053,  0.0256, -0.0653]]], grad_fn=<StackBackward>), tensor([[[ 0.3422,  0.0269, -0.0475]],

        [[ 0.4235, -0.9144,  0.5655]],

        [[ 0.4589,  0.2807, -0.8332]],

        [[-0.0133,  0.0507, -0.1996]]], grad_fn=<StackBackward>))
out: tensor([[[ 0.2774,  0.1639, -0.4460, -0.0228,  0.0086, -0.0369]]],
       grad_fn=<CatBackward>)
hidden: (tensor([[[ 0.2147, -0.0191,  0.0677]],

        [[ 0.2516, -0.4591,  0.3327]],

        [[ 0.2774,  0.1639, -0.4460]],

        [[-0.0228,  0.0086, -0.0369]]], grad_fn=<StackBackward>), tensor([[[ 0.4414, -0.0299,  0.0889]],

        [[ 0.6360, -1.2360,  0.7229]],

        [[ 0.5692,  0.2843, -0.9375]],

        [[-0.0569,  0.0177, -0.1039]]], grad_fn=<StackBackward>))
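
For the bidirectional case it may help to pass the whole sequence in one call, so the backward direction really runs from the last element to the first. The sketch below uses assumed sizes matching the example above; the helper names `h_n_by_layer`, `fwd_last_layer` and `bwd_last_layer` are mine. It shows where the two directions of the last layer show up in output.

```python
import torch
import torch.nn as nn

torch.manual_seed(1)

seq_len, batch, input_size, hidden_size, num_layers = 5, 1, 3, 3, 2
lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, bidirectional=True)
x = torch.randn(seq_len, batch, input_size)

# Whole sequence in one call; initial states default to zeros.
output, (h_n, c_n) = lstm(x)
print(output.shape)  # torch.Size([5, 1, 6]) -> forward and backward halves concatenated
print(h_n.shape)     # torch.Size([4, 1, 3]) -> num_layers * num_directions

# h_n is ordered layer by layer, forward direction first within each layer.
h_n_by_layer = h_n.reshape(num_layers, 2, batch, hidden_size)
fwd_last_layer = h_n_by_layer[-1, 0]
bwd_last_layer = h_n_by_layer[-1, 1]

# The forward direction finishes at the last time step,
# the backward direction finishes at the first time step.
fwd_ok = torch.allclose(output[-1, :, :hidden_size], fwd_last_layer)
bwd_ok = torch.allclose(output[0, :, hidden_size:], bwd_last_layer)
print(fwd_ok, bwd_ok)
```

So with a bidirectional LSTM, output[-1] alone is not "the final state": its backward half has only seen the last element, and the backward final state sits in output[0].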

It really depends on the model you use and how you interpret it. The output may be:

  • a single LSTM cell hidden state
  • several LSTM cell hidden states
  • all the hidden states outputs

The output is almost never interpreted directly. If the input is encoded, there should be a softmax layer to decode the results.

Note: In language modeling, hidden states are used to define the probability of the next word, p(w_{t+1} | w_1, ..., w_t) = softmax(W h_t + b).
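
A minimal sketch of that formula, assuming a toy vocabulary and made-up layer sizes (`embedding`, `decoder` and all dimensions here are hypothetical): a linear layer plays the role of W h_t + b, and softmax turns each time step's result into a distribution over the next word.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy sizes.
vocab_size, embed_dim, hidden_size = 10, 8, 16

embedding = nn.Embedding(vocab_size, embed_dim)
lstm = nn.LSTM(embed_dim, hidden_size)
decoder = nn.Linear(hidden_size, vocab_size)  # computes W h_t + b

tokens = torch.randint(0, vocab_size, (7, 1))  # (seq_len, batch) of word indices
output, _ = lstm(embedding(tokens))            # h_t for every t: (7, 1, 16)
logits = decoder(output)                       # (7, 1, 10)
probs = torch.softmax(logits, dim=-1)          # p(w_{t+1} | w_1, ..., w_t) for every t

print(probs.shape)        # torch.Size([7, 1, 10])
print(probs.sum(dim=-1))  # each entry is (numerically) 1
```

In training you would normally feed `logits` to `nn.CrossEntropyLoss`, which applies the log-softmax internally, rather than calling softmax yourself.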

prosti

The output tensor contains the hidden states from every time step of the RNN (LSTM), and the hidden state returned by the RNN (LSTM) is the one from the last time step of the input sequence. You can check this by collecting the hidden state at each step and comparing the result to the output (provided you are not using pack_padded_sequence).
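
The check described above could be sketched like this, assuming a single-layer, unidirectional LSTM with arbitrary sizes so that the hidden state at each step has the same shape as one slice of output:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

lstm = nn.LSTM(input_size=3, hidden_size=4, num_layers=1)
x = torch.randn(6, 1, 3)  # (seq_len, batch, input_size)

# Full pass: output has one hidden state per time step.
full_output, (h_n, c_n) = lstm(x)

# Step-by-step pass: collect the returned hidden state after each step.
state = None  # None means zero initial (h_0, c_0)
collected = []
for t in range(x.size(0)):
    _, state = lstm(x[t:t + 1], state)
    collected.append(state[0])  # h_t, shape (1, 1, 4)

stacked = torch.cat(collected, dim=0)  # (6, 1, 4)
all_steps_match = torch.allclose(stacked, full_output, atol=1e-6)
last_step_match = torch.allclose(collected[-1], h_n, atol=1e-6)
print(all_steps_match, last_step_match)
```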

Jibin Mathew