4

I have a problem resuming training after saving my model. The loss decreases from, say, 6 to 3, and at that point I save the model. When I restore it and continue training, the loss restarts from 6, so it seems that the restoration doesn't really work. I don't understand why, because when I print the weights they appear to be loaded properly. I use an Adam optimizer. Thanks in advance. Here is the code:

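    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.contrib)

    batch_size = self.batch_size
    num_classes = self.num_classes

    n_hidden = 50 #700
    n_layers = 1 #3
    truncated_backprop = self.seq_len
    dropout = 0.3  # used below as output_keep_prob, i.e. the fraction of outputs kept
    learning_rate = 0.001
    epochs = 200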

    with tf.name_scope('input'):
        x = tf.placeholder(tf.float32, [batch_size, truncated_backprop], name='x')
        y = tf.placeholder(tf.int32, [batch_size, truncated_backprop], name='y')

    with tf.name_scope('weights'):
        W = tf.Variable(np.random.rand(n_hidden, num_classes), dtype=tf.float32)
        b = tf.Variable(np.random.rand(1, num_classes), dtype=tf.float32)

    inputs_series = tf.split(x, truncated_backprop, 1)
    labels_series = tf.unstack(y, axis=1)

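    with tf.name_scope('LSTM'):
        # Build a separate cell object per layer; [cell] * n_layers reuses one
        # object, which later TF 1.x releases can reject when n_layers > 1.
        def make_cell():
            cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, state_is_tuple=True)
            return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=dropout)

        cell = tf.contrib.rnn.MultiRNNCell([make_cell() for _ in range(n_layers)])

    states_series, current_state = tf.contrib.rnn.static_rnn(
        cell, inputs_series, dtype=tf.float32)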

    logits_series = [tf.matmul(state, W) + b for state in states_series]
    prediction_series = [tf.nn.softmax(logits) for logits in logits_series]

    losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
              for logits, labels in zip(logits_series, labels_series)]
    total_loss = tf.reduce_mean(losses)

    train_step = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)

    tf.summary.scalar('total_loss', total_loss)
    summary_op = tf.summary.merge_all()

    loss_list = []
    writer = tf.summary.FileWriter('tf_logs', graph=tf.get_default_graph())

    all_saver = tf.train.Saver()

    with tf.Session() as sess:
        # The graph is already built above, so restore into it with the existing
        # Saver rather than importing the meta graph a second time.
        # For a first run with no checkpoint, initialize instead:
        # sess.run(tf.global_variables_initializer())
        all_saver.restore(sess, './models/tf_models/rnn_model')

        for epoch_idx in range(epochs):
            xx, yy = next(self.get_batch)
            batch_count = len(self.D.chars) // batch_size // truncated_backprop

            for batch_idx in range(batch_count):
                batchX, batchY = next(self.get_batch)

                summ, _total_loss, _train_step, _current_state, _prediction_series = sess.run(
                    [summary_op, total_loss, train_step, current_state, prediction_series],
                    feed_dict={x: batchX, y: batchY})

                loss_list.append(_total_loss)
                writer.add_summary(summ, epoch_idx * batch_count + batch_idx)
                if batch_idx % 5 == 0:
                    print('Step', batch_idx, 'Batch_loss', _total_loss)

                if batch_idx % 50 == 0:
                    all_saver.save(sess, 'models/tf_models/rnn_model')

            if epoch_idx % 5 == 0:
                print('Epoch', epoch_idx, 'Last_loss', loss_list[-1])
JimZer
  • Well, weights are properly restored, but what about data? Is it the same? – Dmitriy Danevskiy Apr 12 '17 at 13:24
  • @DanevskyiDmytro My data comes in batches. The order in which the batches are retrieved is random, but the loss was near 3 over the whole data set (an entire epoch). So I would expect that after restoring, the loss would restart near 3 for any batch? – JimZer Apr 12 '17 at 13:44
  • Could you limit your dataset to a few batches and perform train and test on them? – Dmitriy Danevskiy Apr 12 '17 at 15:03

2 Answers

1

I had the same problem. In my case the model was being restored correctly, but the loss kept starting really high again and again. The cause was that my batch retrieval was not random. I had three classes, A, B and C, and my data was being fed in that order: first all of A, then B, then C. I don't know if that is your problem, but you must ensure that every batch you give to your model contains all of your classes, so in your case each batch should have about batch_size/num_classes inputs per class. I changed that and everything worked perfectly :)

Check whether you are feeding your model correctly.
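As a rough illustration of the point above, here is a minimal sketch of one common way to mix classes within batches: shuffle the example indices before slicing them into batches. This does not enforce an exact batch_size/num_classes split per class; it only makes single-class batches very unlikely. The function name `shuffled_batches` and the toy data are assumptions made for the example, not part of the question's code.

```python
import random


def shuffled_batches(examples, labels, batch_size, seed=None):
    """Yield (examples, labels) batches in a shuffled order.

    Shuffling the indices mixes the classes, so a batch is unlikely to
    contain a single class even if the data is stored sorted by class.
    """
    indices = list(range(len(examples)))
    random.Random(seed).shuffle(indices)
    # Drop the last partial batch so every batch has the same size.
    for start in range(0, len(indices) - batch_size + 1, batch_size):
        chosen = indices[start:start + batch_size]
        yield [examples[i] for i in chosen], [labels[i] for i in chosen]


# Hypothetical data stored class by class: all A, then all B, then all C.
labels = ["A"] * 30 + ["B"] * 30 + ["C"] * 30
examples = list(range(len(labels)))
batches = list(shuffled_batches(examples, labels, batch_size=15, seed=0))
```

Calling the generator again with a different seed (or `seed=None`) at the start of each epoch gives a new order every time.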

0

My problem was a bug in my label code: the labels were changing between two runs. It works now. Thank you for the help.
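The answer does not say how the labels changed, so the following is only a sketch of one common cause and a possible guard against it. If the character-to-id mapping is built by iterating over a `set`, its order can differ between Python processes, and a restored model then sees different ids for the same characters. Sorting the characters and saving the mapping next to the checkpoint keeps it stable. The helper names and the `vocab.json` file name are assumptions made for illustration.

```python
import json
import os
import tempfile


def build_vocab(text):
    """Map each distinct character to an integer id in a deterministic order."""
    # sorted() fixes the order; iterating a bare set of strings can differ
    # between Python processes because string hashing is randomized.
    return {ch: idx for idx, ch in enumerate(sorted(set(text)))}


def load_or_create_vocab(text, path):
    """Reuse the mapping saved with the checkpoint, or create and save it."""
    if os.path.exists(path):
        with open(path, encoding="utf-8") as f:
            return json.load(f)
    vocab = build_vocab(text)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(vocab, f)
    return vocab


# Hypothetical usage: the file location is an assumption, not from the question.
vocab_path = os.path.join(tempfile.mkdtemp(), "vocab.json")
first_run = load_or_create_vocab("hello world", vocab_path)
second_run = load_or_create_vocab("hello world", vocab_path)
```

With this in place, a resumed run reads the same mapping that was used when the checkpoint was written instead of rebuilding it.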

JimZer