
I created my own network model and trained it on data with shape [batchsz, 10, 8].

After that, I wanted to run net.predict(x) with a tensor x of shape [1, 40, 8], but I got this error:

Input shape axis 1 must equal 10, got shape [1,40,8]

In my opinion, axis 1 should only affect the number of calls to the LSTMCell, so why must it equal 10? How can I work around this problem?
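For what it's worth, my understanding of the check that fails is something like the following pure-Python sketch (no TensorFlow involved; `check_input` and `trained_shape` are names I made up to illustrate, not real TF code):

```python
# Simplified mimic of the static shape check TensorFlow performs:
# the model was built/traced with inputs of shape (batchsz, 10, 8),
# so every later call must match on all axes except the batch axis.
trained_shape = (None, 10, 8)   # None = flexible batch dimension

def check_input(shape, expected=trained_shape):
    """Raise the same kind of error TF gives on a shape mismatch."""
    for axis, (got, want) in enumerate(zip(shape, expected)):
        if want is not None and got != want:
            raise ValueError(
                f"Input shape axis {axis} must equal {want}, "
                f"got shape {list(shape)}"
            )

check_input((1, 10, 8))   # OK: time dimension matches training
try:
    check_input((1, 40, 8))
except ValueError as e:
    print(e)   # Input shape axis 1 must equal 10, got shape [1, 40, 8]
```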

At the same time, I also created a variable in the network to determine whether it is training, because when it is not training I want to call each LSTMCell only once to obtain the output. With the problem above, though, I can't achieve this goal.

So please help me.

Here is the code:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class lstm_rnn(keras.Model):
    def __init__(self, units):
        super(lstm_rnn, self).__init__()

        # Initial [h, c] states for the two stacked cells;
        # batchsz is defined elsewhere in my script.
        self.state0 = [tf.zeros([batchsz, units]), tf.zeros([batchsz, units])]
        self.state1 = [tf.zeros([batchsz, units]), tf.zeros([batchsz, units])]

        self.lstm_cell0 = layers.LSTMCell(units, dropout=0.5)
        self.lstm_cell1 = layers.LSTMCell(units, dropout=0.5)

    def call(self, inputs):
        x = inputs
        real_out = 0

        # Use the size of axis 1 to decide whether we are training.
        axis_size = x.shape[1]
        is_training = axis_size != 2

        print(inputs.shape)

        if is_training:
            state0 = self.state0
            state1 = self.state1
            step_cnt = 0

            # One LSTMCell call per timestep along axis 1.
            for word in tf.unstack(x, axis=1):
                out0, state0 = self.lstm_cell0(word, state0)
                out1, state1 = self.lstm_cell1(out0, state1)
                if step_cnt == 0:
                    real_out = tf.reshape(out1, shape=(20, 1, 6))
                else:
                    real_out = tf.concat(
                        [real_out, tf.reshape(out1, shape=(20, 1, 6))],
                        axis=1)
                step_cnt += 1
        else:
            # Inference: reuse inputs[0] as the initial state and
            # call each cell exactly once with inputs[1].
            state0 = [inputs[0], inputs[0]]
            state1 = [inputs[0], inputs[0]]
            info = inputs[1]
            out0, state0 = self.lstm_cell0(info, state0)
            out1, state1 = self.lstm_cell1(out0, state1)
            real_out = out1

        return real_out
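To show concretely what I mean by "axis 1 only affects the number of calls": a pure-Python sketch of the unrolling loop in call(), using plain nested lists instead of tensors (`unrolled_steps` is just an illustrative name):

```python
# The loop `for word in tf.unstack(x, axis=1)` runs once per timestep,
# so the loop itself handles 10 steps or 40 steps alike; the fixed
# size comes from the shape the model was built with, not the loop.
def unrolled_steps(x):
    """x is a nested list shaped [batch, time, features]."""
    return len(x[0])  # size of axis 1 == number of LSTMCell calls

batch10 = [[[0.0] * 8 for _ in range(10)]]   # like shape [1, 10, 8]
batch40 = [[[0.0] * 8 for _ in range(40)]]   # like shape [1, 40, 8]
print(unrolled_steps(batch10))   # 10
print(unrolled_steps(batch40))   # 40
```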
  • You can understand the input shapes of the LSTM cell from this post: https://stackoverflow.com/questions/39324520/ – Dec 10 '20 at 16:36

0 Answers