0

I'm trying to build a convolutional LSTM autoencoder (that also predicts future and past frames) with TensorFlow. It works to a certain degree, but the error sometimes jumps back up, so essentially it never converges.

The model is as follows:

The encoder receives one 64x64 frame of a 20-frame bouncing-MNIST video at each LSTM time step. Every stacked LSTM layer halves the spatial size and increases the depth via 2x2 convolutions with a stride of 2 (so 64x64 → 32x32x3 → … → 1x1x96). In parallel, each LSTM layer performs 3x3 convolutions with a stride of 1 on its own state. Both results are concatenated to form the new state, which is why the state shapes in the code (`shape1` … `shape6`) have twice these depths: 32x32x6 … 1x1x192. In the same way, the decoder uses transposed convolutions to go back to the original format. Then the squared error is calculated.

The error starts at around 2700, and it takes around 20 hours (GeForce 1060) to get down to ~1700. At that point the jumps (sometimes back up to 2300, or even to ridiculous values like 440300) happen often enough that I can't really get any lower. Also, by then the model can usually pinpoint where the digit should be, but the output is too fuzzy to actually make out the digit...

I have tried different learning rates and optimizers, so if anybody knows why the error jumps like that, that'd make me happy :)

Here is a graph of the loss over epochs: (image not included)

import tensorflow as tf
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import os
# Pin this process to GPU "0", with CUDA devices numbered by PCI bus ID.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="0",
)

# Convolutional LSTM cell, based on code by loliverhennigh (Github).
class ConvCell(tf.contrib.rnn.RNNCell):
    """Convolutional LSTM cell for use inside MultiRNNCell / dynamic_rnn.

    Each step combines two convolution paths:

    * temporal path: 3x3, stride-1, SAME convolutions over the previous
      hidden state ``h`` (and, for the gates only, over the previous cell
      state ``c`` as a peephole-style term);
    * input path: a 2x2, stride-2, VALID convolution of ``inputs``
      (halving height and width), or with ``transpose=True`` a 2x2,
      stride-2 transposed convolution whose output is sized to the state's
      spatial shape.

    Each path produces ``num_features`` channels and the two are
    concatenated, so the state depth ``shape[3]`` must equal
    ``2 * num_features`` (otherwise ``new_c`` and ``c`` cannot be combined).

    Args:
        shape: full state shape ``[batch, height, width, depth]``; note that
            the batch dimension is included (see ``state_size``).
        num_features: channels produced by each of the two paths.
        transpose: up-sample the input instead of down-sampling it.
    """
    count = 0   #class-wide instance counter; gives every cell its own variable scope "ConvCell<k>"
    def __init__(self, shape, num_features, transpose = False):
        self.shape = shape 
        self.num_features = num_features
        self._state_is_tuple = True
        self._transpose = transpose
        # The scope suffix follows construction order: creating cells in a
        # different order renames their variables, so previously saved
        # checkpoints would no longer match by name.
        ConvCell.count+=1
        self.count = ConvCell.count

    @property
    def state_size(self):
        """(c, h) sizes, each the full ``shape`` including the batch dimension.

        NOTE(review): standard RNNCells report state sizes without the batch
        dimension. This file always passes an explicit ``initial_state``, so
        ``zero_state()`` is presumably never relied on -- confirm before using it.
        """
        return (tf.contrib.rnn.LSTMStateTuple(self.shape[0:4],self.shape[0:4]))

    @property
    def output_size(self):
        """Per-example output shape ``[height, width, depth]`` (``shape`` minus batch)."""
        return tf.TensorShape(self.shape[1:4])

    # The actual conv-LSTM step. With transpose=True the input path uses a
    # transposed convolution (up-sampling) instead of a strided convolution.
    def __call__(self, inputs, state, scope=None):
        """Run one time step.

        Args:
            inputs: 4-D input batch, channels last (``inputs.shape[3]`` is used
                as the channel count, matching tf.nn.conv2d's default NHWC).
            state: ``LSTMStateTuple(c, h)``, each shaped like ``self.shape``.
            scope: optional variable-scope name; defaults to ``ConvCell<k>``.

        Returns:
            ``(new_h, LSTMStateTuple(new_c, new_h))``.
        """
        with tf.variable_scope(scope or type(self).__name__+str(self.count)): 
            c, h = state
            state_shape = h.shape
            input_shape = inputs.shape

            #temporal path: 3x3 convolutions over the previous hidden state h,
            #once for the new features and once for the 3 gate channels
            h_filters = tf.get_variable("h_filters",[3,3,state_shape[3],self.num_features])
            h_filters_gates = tf.get_variable("h_filters_gates",[3,3,state_shape[3],3])
            h_partial = tf.nn.conv2d(h,h_filters,[1,1,1,1],'SAME')
            h_partial_gates = tf.nn.conv2d(h,h_filters_gates,[1,1,1,1],'SAME')

            #peephole-style term: the previous cell state c also feeds all three gates
            c_filters = tf.get_variable("c_filters",[3,3,state_shape[3],3])
            c_partial = tf.nn.conv2d(c,c_filters,[1,1,1,1],'SAME')

            #input path: strided 2x2 convolution (down-sample) or, when transposing,
            #a 2x2 transposed convolution whose output matches the state's shape
            #(conv2d_transpose filters are laid out [h, w, out_channels, in_channels])
            if self._transpose:
                x_filters = tf.get_variable("x_filters",[2,2,self.num_features,input_shape[3]])
                x_filters_gates = tf.get_variable("x_filters_gates",[2,2,3,input_shape[3]])
                x_partial = tf.nn.conv2d_transpose(inputs,x_filters,[int(state_shape[0]),int(state_shape[1]),int(state_shape[2]),self.num_features],[1,2,2,1],'VALID')
                x_partial_gates = tf.nn.conv2d_transpose(inputs,x_filters_gates,[int(state_shape[0]),int(state_shape[1]),int(state_shape[2]),3],[1,2,2,1],'VALID')
            else:
                x_filters = tf.get_variable("x_filters",[2,2,input_shape[3],self.num_features])
                x_filters_gates = tf.get_variable("x_filters_gates",[2,2,input_shape[3],3])
                x_partial = tf.nn.conv2d(inputs,x_filters,[1,2,2,1],'VALID')
                x_partial_gates = tf.nn.conv2d(inputs,x_filters_gates,[1,2,2,1],'VALID')

            #biases; no initializer is passed, so TF's default variable initializer
            #is used (there is no +1 forget-gate bias offset as in many LSTMs)
            gate_bias = tf.get_variable("gate_bias",[1,1,1,3])
            h_bias = tf.get_variable("h_bias",[1,1,1,self.num_features*2])

            gates = h_partial_gates + x_partial_gates + c_partial + gate_bias

            #one channel per gate (input, forget, output): each gate is a single
            #spatial map that broadcasts across every state channel below
            i,f,o = tf.split(gates,3,axis=3)

            #concatenate the temporal (h) and spatial (input) features into the
            #candidate, giving 2*num_features channels
            concat = tf.concat([h_partial,x_partial],3) + h_bias

            # NOTE(review): unlike a standard LSTM (candidate = tanh(...),
            # h = o * tanh(c)), the candidate here is relu(...) and h = o * c,
            # so neither c nor h is bounded above. Worth checking as a possible
            # contributor to occasional loss blow-ups.
            new_c = tf.nn.relu(concat)*tf.sigmoid(i)+c*tf.sigmoid(f)
            new_h = new_c * tf.sigmoid(o)

            new_state = tf.contrib.rnn.LSTMStateTuple(new_c,new_h)
            return new_h, new_state #h is both the step output and part of the state, per the RNNCell (output, new_state) convention


# Run mode and hyper-parameters.
TEST = False    #manual switch to go from training to testing
LEARNING_RATE = 0.005
ITERATIONS_PER_EPOCH = 80
# Testing feeds one video at a time; training uses mini-batches of 75 videos.
BATCH_SIZE = 1 if TEST else 75

# Video batches, time-major: (frames, batch, height, width, channels).
inputs = tf.placeholder(tf.float32, (20, BATCH_SIZE, 64, 64, 1))

# Per-level state shapes [batch, height, width, depth]. Each level halves the
# spatial size; shape0 is the full-resolution level the decoders end on.
shape0, shape1, shape2, shape3, shape4, shape5, shape6 = [
    [BATCH_SIZE, side, side, depth]
    for side, depth in zip((64, 32, 16, 8, 4, 2, 1), (2, 6, 12, 24, 48, 96, 192))
]

# MultiRNNCell wants one LSTMStateTuple(c, h) per stacked layer, each
# zero-filled with that layer's full state shape (batch dimension included).
# Encoder layers run fine -> coarse; decoder layers run coarse -> fine.
initial_state1 = tuple(
    tf.contrib.rnn.LSTMStateTuple(tf.zeros(layer_shape), tf.zeros(layer_shape))
    for layer_shape in (shape1, shape2, shape3, shape4, shape5, shape6)
)
initial_state2 = tuple(
    tf.contrib.rnn.LSTMStateTuple(tf.zeros(layer_shape), tf.zeros(layer_shape))
    for layer_shape in (shape5, shape4, shape3, shape2, shape1, shape0)
)

# Encoder: six stacked ConvCells, each halving the spatial size of its input.
# Cells are created in the same order as before; that order fixes each cell's
# "ConvCell<k>" variable scope and therefore the names stored in checkpoints.
cell1, cell2, cell3, cell4, cell5, cell6 = [
    ConvCell(state_shape, features)
    for state_shape, features in (
        (shape1, 3), (shape2, 6), (shape3, 12), (shape4, 24), (shape5, 48), (shape6, 96),
    )
]

mcell = tf.contrib.rnn.MultiRNNCell([cell1, cell2, cell3, cell4, cell5, cell6])

# Encode all 20 frames; rnn_outputs holds the last (1x1) layer's output per step.
rnn_outputs, rnn_states = tf.nn.dynamic_rnn(
    mcell,
    inputs[0:20, :, :, :],
    initial_state=initial_state1,
    dtype=tf.float32,
    time_major=True,
)


# Decoders: two independent stacks of transposed ConvCells that up-sample the
# 1x1 encoder output back to 64x64. Stack "a" is used for forward prediction
# and stack "b" for backward (past) prediction. Cells are created a-stack
# first, coarse to fine, exactly as before, so variable-scope names are unchanged.
_decoder_layers = ((shape5, 48), (shape4, 24), (shape3, 12), (shape2, 6), (shape1, 3), (shape0, 1))

cell9a, cell10a, cell11a, cell12a, cell13a, cell14a = [
    ConvCell(state_shape, features, transpose=True) for state_shape, features in _decoder_layers
]
mcella = tf.contrib.rnn.MultiRNNCell([cell9a, cell10a, cell11a, cell12a, cell13a, cell14a])

cell9b, cell10b, cell11b, cell12b, cell13b, cell14b = [
    ConvCell(state_shape, features, transpose=True) for state_shape, features in _decoder_layers
]
mcellb = tf.contrib.rnn.MultiRNNCell([cell9b, cell10b, cell11b, cell12b, cell13b, cell14b])

def PredictionLayer(rnn_outputs,viewPoint = 11, reverse = False):
    """Decode one encoder time step into a frame sequence and score it.

    The encoder output after frame ``viewPoint - 1`` is the first decoder
    input; it is followed by all-zero inputs for the remaining steps.

    * ``reverse=False``: the forward decoder (``mcella``) reconstructs frame
      ``viewPoint - 1`` and predicts every later frame.
    * ``reverse=True``: the backward decoder (``mcellb``) reconstructs frames
      ``viewPoint - 2`` down to 0 (newest first).

    Args:
        rnn_outputs: time-major encoder outputs, one step per input frame.
        viewPoint: 1-based index of the last frame the encoder has seen.
            Valid range is 1..seq_len forward and 2..seq_len in reverse.
        reverse: decode into the past instead of the future.

    Returns:
        In TEST mode, the decoded frames (channel mean of the decoder output),
        shaped (steps, batch, 64, 64). Otherwise the summed squared error
        against the matching slice of ``inputs``.

    Raises:
        ValueError: if ``viewPoint`` is out of range. Without this check some
            bad values fail deep inside TensorFlow, and ``viewPoint=0`` (forward)
            would silently broadcast every decoded frame against the last
            input frame.
    """
    seq_len = inputs.get_shape().as_list()[0]   # frames per video (20 here)
    min_view = 2 if reverse else 1
    if not min_view <= viewPoint <= seq_len:
        raise ValueError(
            "viewPoint must be in [%d, %d] for reverse=%s, got %r"
            % (min_view, seq_len, reverse, viewPoint))

    # Number of zero decoder inputs that follow the real encoder output;
    # together with it they span exactly the frames compared below.
    predLength = viewPoint-2 if reverse else seq_len-viewPoint
    # Each zero input must look like one encoder output step (1x1x192 here).
    blank_shape = [predLength, BATCH_SIZE] + rnn_outputs.get_shape().as_list()[2:]
    vision = tf.concat([rnn_outputs[viewPoint-1:viewPoint], tf.zeros(blank_shape)], 0)

    decoder = mcellb if reverse else mcella
    rnn_outputs2, rnn_states = tf.nn.dynamic_rnn(decoder, vision, initial_state = initial_state2, time_major=True)

    # Collapse the decoder's output channels into a single greyscale frame.
    mean = tf.reduce_mean(rnn_outputs2,4)

    if TEST:
        return mean

    if reverse:
        target = inputs[viewPoint-2::-1, :, :, :, 0]   # frames viewPoint-2 .. 0
    else:
        target = inputs[viewPoint-1:seq_len, :, :, :, 0]
    return tf.reduce_sum(tf.square(mean - target))



# Optional global-norm gradient clipping, a common remedy for sudden loss
# spikes. None keeps plain RMSProp `minimize` (the original behaviour);
# set e.g. 5.0 to enable clipping.
GRADIENT_CLIP_NORM = None

if TEST:
    # Past frames are decoded newest-first, so flip them before appending the
    # forward part. (This previously called an undefined createPredictionLayer.)
    mean = tf.concat([PredictionLayer(rnn_outputs,11,True)[::-1,:,:,:],PredictionLayer(rnn_outputs,11)],0)
else:   #training part of the graph
    error = tf.zeros([1])
    # One forward and one backward decoder per view point 8..14. Each pair
    # covers all 20 frames, so the error spans 7 * 20 decoded frames.
    # NOTE(review): the author found 7 view points work but 9 or more do not.
    # Every iteration adds two more unrolled decoders to the graph, so GPU
    # memory is a plausible cause -- unconfirmed.
    for i in range(8,15):
        error += PredictionLayer(rnn_outputs, i)
        error += PredictionLayer(rnn_outputs, i, True)

    optimizer = tf.train.RMSPropOptimizer(learning_rate=LEARNING_RATE)
    if GRADIENT_CLIP_NORM is None:
        train_fn = optimizer.minimize(error)
    else:
        grads, variables = zip(*optimizer.compute_gradients(error))
        # clip_by_global_norm ignores None entries; apply_gradients skips them.
        clipped, _ = tf.clip_by_global_norm(grads, GRADIENT_CLIP_NORM)
        train_fn = optimizer.apply_gradients(zip(clipped, variables))



################################################################################
##                           TRAINING LOOP                                    ##
################################################################################
#code based on siemanko/tf_lstm.py (Github)

gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
saver = tf.train.Saver(restore_sequentially=True, allow_empty=True,)
session = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
session.run(tf.global_variables_initializer())
vids = np.load("mnist_test_seq.npy") #20/10000/64/64 , moving mnist dataset from http://www.cs.toronto.edu/~nitish/unsupervised_video/
vids = vids[:,0:6000,:,:]   #training set
# Resume from the newest checkpoint if one exists. latest_checkpoint() returns
# None when the directory holds no checkpoint, and Saver.restore cannot take
# None -- this used to crash the very first training run.
checkpoint = tf.train.latest_checkpoint('./conv_lstm_multiples_v2/')
if checkpoint is not None:
    saver.restore(session, checkpoint)
    print("restored %s" % checkpoint)
else:
    print("no checkpoint found in ./conv_lstm_multiples_v2/, starting from freshly initialized weights")


for epoch in range(1000):
    if TEST:
        break
    epoch_error = 0

    #randomize batches each epoch: shuffle the videos (axis 1) in place
    vids = np.swapaxes(vids,0,1)
    np.random.shuffle(vids)
    vids = np.swapaxes(vids,0,1)

    for i in range(ITERATIONS_PER_EPOCH):
        #running the graph and feeding data
        batch = vids[:, i*BATCH_SIZE:(i+1)*BATCH_SIZE, :, :]
        err,_ = session.run([error, train_fn], {inputs: np.expand_dims(batch,axis=4)})
        # `error` is a shape-[1] tensor; take a plain float so printing and the
        # "%f" formatting below do not rely on array-to-scalar conversion.
        err = np.asarray(err).item()

        print(err)
        epoch_error += err

    # Mean squared error per predicted pixel: each batch scores BATCH_SIZE
    # videos x 7 view points x 20 frames x 4096 (64x64) pixels (assumes the
    # 7 view points used when building the training graph).
    epoch_error /= (ITERATIONS_PER_EPOCH*BATCH_SIZE*4096*20*7)
    if (epoch+1) % 5 == 0:
        # os.path.join instead of a hard-coded '.\...' literal: the backslash
        # form only names the intended directory on Windows. Elsewhere it is a
        # single file name in the working directory, which
        # latest_checkpoint('./conv_lstm_multiples_v2/') never finds.
        checkpoint_prefix = os.path.join('.', 'conv_lstm_multiples_v2', 'iteration')
        os.makedirs(os.path.dirname(checkpoint_prefix), exist_ok=True)
        saver.save(session,checkpoint_prefix,global_step=epoch)
        print("saved")
    print("Epoch %d, train error: %f" % (epoch, epoch_error))

#testing: `mean` (the decoded sequence) only exists when the graph is built
#with TEST = True, so visualise only in that mode. Previously this ran after
#every training run and ended in a NameError.
if TEST:
    plt.ion()
    f, axarr = plt.subplots(2)
    vids = np.load("mnist_test_seq.npy")

    # videos 6000..9999 were excluded from the training slice
    for i in range(6000,10000):
        img = session.run([mean], {inputs: np.expand_dims(vids[:,i:i+1,:,:],axis=4)})
        for j in range(20):
            axarr[0].imshow(img[0][j,0,:,:])   # top: model output
            axarr[1].imshow(vids[j,i,:,:])     # bottom: ground truth
            plt.show()
            plt.pause(0.1)
Xethoras
  • 26
  • 5

1 Answer

1

Usually this happens when the gradient magnitude becomes very large at some step, which causes your network parameters to change a lot at once. To verify that this is indeed the case, plot the gradient magnitudes the same way you plotted the loss and check whether they spike right before the loss jumps. If that is the case, the classic remedy is gradient clipping (or, going all the way, natural-gradient methods).

iga
  • 3,571
  • 1
  • 12
  • 22