
I am trying to compute the total number of parameters of an LSTM model, and I am confused.

I have read some answers, such as this post and this post, but I still don't understand what role the hidden units play in the parameter computation (h1=64, h2=128 in my case).

import tensorflow as tf

b, t, d_in, d_out = 32, 256, 161, 257

data = tf.placeholder("float", [b, t, d_in])  # [batch, timestep, dim_in]
labels = tf.placeholder("float", [b, t, d_out])  # [batch, timestep, dim_out]

myinput = data
batch_size, seq_len, dim_in = myinput.shape

rnn_layers = []

h1 = 64
c1 = tf.nn.rnn_cell.LSTMCell(h1)
rnn_layers.append(c1)

h2 = 128
c2 = tf.nn.rnn_cell.LSTMCell(h1)
rnn_layers.append(c2)

multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)
rnnoutput, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell, 
inputs=myinput, dtype=tf.float32)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

all_trainable_vars = tf.reduce_sum([tf.reduce_prod(v.shape) for v in tf.trainable_variables()])
print(sess.run(all_trainable_vars))

I printed the total number of parameters using TensorFlow, and it reported 90880. How can I derive this result step by step? Thank you.

berlloon

1 Answer

In your case, you defined an LSTM cell with the line c1 = tf.nn.rnn_cell.LSTMCell(h1). To answer your question, I will first introduce the mathematical definition of an LSTM, as shown in the picture below (source: Wikipedia, LSTM):

[Image: LSTM]

The subscript t denotes time step t.

  • f_t is the forget gate.
  • i_t is the input gate.
  • o_t is the output gate.
  • c_t and h_t are the cell state and the hidden state of the LSTM cell, respectively.
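
For reference, the standard LSTM equations can be written as follows. This is a sketch in the usual Wikipedia notation, which is assumed to match the picture: x_t is the input at time t, W and U are the input and recurrent weight matrices, b is the bias, and \tilde{c}_t is the candidate cell state.

```latex
\begin{aligned}
f_t &= \sigma_g(W_f x_t + U_f h_{t-1} + b_f) \\
i_t &= \sigma_g(W_i x_t + U_i h_{t-1} + b_i) \\
o_t &= \sigma_g(W_o x_t + U_o h_{t-1} + b_o) \\
\tilde{c}_t &= \sigma_c(W_c x_t + U_c h_{t-1} + b_c) \\
c_t &= f_t \circ c_{t-1} + i_t \circ \tilde{c}_t \\
h_t &= o_t \circ \sigma_h(c_t)
\end{aligned}
```

With d = dim(x_t) and h = dim(h_t), each W is an h × d matrix, each U is an h × h matrix, and each b is a vector of length h.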

For tf.nn.rnn_cell.LSTMCell(h1), h1=64 is the dimension of h_t, i.e. dim(h_t) = 64.
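
To turn this into a parameter count, here is a minimal sketch, assuming a standard LSTM cell without peepholes or a projection layer (which I believe is the default for tf.nn.rnn_cell.LSTMCell). Each layer has four blocks (three gates plus the cell candidate), and each block has an input weight matrix, a recurrent weight matrix, and a bias. The function names are hypothetical and only for illustration.

```python
def lstm_param_count(input_dim, hidden_dim):
    """Parameter count of one standard LSTM layer (no peepholes, no projection)."""
    # Four blocks: forget gate, input gate, output gate, cell candidate.
    # Each block has weights of shape (hidden_dim, input_dim) for x_t,
    # weights of shape (hidden_dim, hidden_dim) for h_{t-1},
    # and a bias of length hidden_dim.
    return 4 * (hidden_dim * (input_dim + hidden_dim) + hidden_dim)


def stacked_lstm_param_count(input_dim, hidden_dims):
    """Total parameter count of LSTM layers stacked as in MultiRNNCell."""
    total = 0
    for hidden_dim in hidden_dims:
        total += lstm_param_count(input_dim, hidden_dim)
        # The next layer receives this layer's h_t as its input.
        input_dim = hidden_dim
    return total


# Code as posted in the question: both cells are built with h1 = 64.
print(stacked_lstm_param_count(161, [64, 64]))   # 90880
# Intended configuration: h1 = 64, h2 = 128.
print(stacked_lstm_param_count(161, [64, 128]))  # 156672
```

Note that the batch size and the number of time steps do not appear in the count: the same weights are reused at every time step.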

guorui
  • Thank you so much for your explanation. One more thing: do you know how I can get the final result 90880? I mean, what are the dimensions of W, U, b for the two layers (h1=64, h2=128)? – berlloon Jul 25 '19 at 02:02
  • Sorry, I made a mistake: c2 = tf.nn.rnn_cell.LSTMCell(h1) should be c2 = tf.nn.rnn_cell.LSTMCell(h2). In my case, the first layer has 4 * (64 * (64 + 161) + 64) = 57856 parameters, and the second layer has 4 * (128 * (128 + 64) + 128) = 98816 parameters. In total, 57856 + 98816 = 156672, which is the same as what TensorFlow printed. – berlloon Jul 25 '19 at 06:52