
The following network code, which should be your classic simple LSTM language model, starts outputting NaN loss after a while. On my training set it takes a couple of hours, and I couldn't replicate it easily on smaller datasets, but it always happens in serious training runs.

`sparse_softmax_cross_entropy_with_logits` should be numerically stable, so it can't be the cause... but apart from that, I don't see any other node in the graph that could cause an issue. What could be the problem?

import tensorflow as tf


class MyLM():
    def __init__(self, batch_size, embedding_size, hidden_size, vocab_size):
        self.x = tf.placeholder(tf.int32, [batch_size, None])  # [batch_size, seq-len]
        self.lengths = tf.placeholder(tf.int32, [batch_size])  # [batch_size]

        # remove padding. [batch_size * seq_len] -> [batch_size * sum(lengths)]
        mask = tf.sequence_mask(self.lengths)  # [batch_size, seq_len]
        mask = tf.cast(mask, tf.int32)  # [batch_size, seq_len]
        mask = tf.reshape(mask, [-1])  # [batch_size * seq_len]

        # remove padding + last token. [batch_size * seq_len] -> [batch_size * sum(lengths-1)]
        mask_m1 = tf.cast(tf.sequence_mask(self.lengths - 1, maxlen=tf.reduce_max(self.lengths)), tf.int32)  # [batch_size, seq_len]
        mask_m1 = tf.reshape(mask_m1, [-1])  # [batch_size * seq_len]

        # remove padding + first token.  [batch_size * seq_len] -> [batch_size * sum(lengths-1)]
        m1_mask = tf.cast(tf.sequence_mask(self.lengths - 1), tf.int32)  # [batch_size, seq_len-1]
        m1_mask = tf.concat([tf.cast(tf.zeros([batch_size, 1]), tf.int32), m1_mask], axis=1)  # [batch_size, seq_len]
        m1_mask = tf.reshape(m1_mask, [-1])  # [batch_size * seq_len]

        embedding = tf.get_variable("TokenEmbedding", shape=[vocab_size, embedding_size])
        x_embed = tf.nn.embedding_lookup(embedding, self.x)  # [batch_size, seq_len, embedding_size]

        lstm = tf.nn.rnn_cell.LSTMCell(hidden_size, use_peepholes=True)

        # outputs shape: [batch_size, seq_len, hidden_size]
        outputs, final_state = tf.nn.dynamic_rnn(lstm, x_embed, dtype=tf.float32,
                                                 sequence_length=self.lengths)
        outputs = tf.reshape(outputs, [-1, hidden_size])  # [batch_size * seq_len, hidden_size]

        w = tf.get_variable("w_out", shape=[hidden_size, vocab_size])
        b = tf.get_variable("b_out", shape=[vocab_size])
        logits_padded = tf.matmul(outputs, w) + b  # [batch_size * seq_len, vocab_size]
        self.logits = tf.dynamic_partition(logits_padded, mask_m1, 2)[1]  # [batch_size * sum(lengths-1), vocab_size]

        predict = tf.argmax(logits_padded, axis=1)  # [batch_size * seq_len]
        self.predict = tf.dynamic_partition(predict, mask, 2)[1]  # [batch_size * sum(lengths)]

        flat_y = tf.dynamic_partition(tf.reshape(self.x, [-1]), m1_mask, 2)[1]  # [batch_size * sum(lengths-1)]

        self.cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=flat_y)
        self.cost = tf.reduce_mean(self.cross_entropy)
        self.train_step = tf.train.AdamOptimizer(learning_rate=0.01).minimize(self.cost)
  • Does it directly go from having reasonable loss values to all of a sudden being NaN or is there a gradual increase in loss until eventually it gets out of control? – Aaron Aug 25 '17 at 20:43
  • The loss hovers at around 2 and then it becomes suddenly NaN. – Chum-Chum Scarecrows Aug 26 '17 at 04:33
  • Something I've done when debugging this sort of thing in the past is to make sure and exit the training loop as soon as the first NaN occurs. Then look at whatever data was in the last mini-batch and see if there are any anomalies. For instance, there could be a sequence of length zero that is screwing things up. – Aaron Aug 26 '17 at 06:00
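A minimal sketch of the kind of check Aaron describes, assuming a plain feed-dict training loop; the names `model`, `sess`, `batches`, and `vocab_size` are placeholders and not part of the original code:

import numpy as np

for x_batch, len_batch in batches:
    feed = {model.x: x_batch, model.lengths: len_batch}
    _, cost = sess.run([model.train_step, model.cost], feed_dict=feed)
    if np.isnan(cost):
        # stop right away and inspect the batch that produced the NaN
        print("NaN loss encountered")
        print("sequence lengths:", len_batch)  # e.g. a zero-length sequence
        print("out-of-range token ids:", bool((x_batch < 0).any() or (x_batch >= vocab_size).any()))
        break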

2 Answers


Check the columns that are fed to the model. In my case, there was a column containing NaN values; after removing the NaNs, it worked.
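For example, a quick NaN scan over the raw input columns before they are fed to the model; the helper below is just an illustration, not something from the original post:

import pandas as pd

def check_for_nans(df: pd.DataFrame) -> pd.DataFrame:
    # report the NaN count per column, then drop the offending rows
    print(df.isna().sum())
    return df.dropna()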

Rahul Sood

This may be a case of exploding gradients: gradients can grow very large during backpropagation through an LSTM, resulting in numeric overflow. A common technique for dealing with exploding gradients is gradient clipping.
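A minimal sketch of what that could look like here, replacing the last line of the question's `__init__`; the clip norm of 5.0 is just an example value:

optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
grads, variables = zip(*optimizer.compute_gradients(self.cost))
# rescale the gradients whenever their global norm exceeds 5.0
clipped_grads, _ = tf.clip_by_global_norm(grads, clip_norm=5.0)
self.train_step = optimizer.apply_gradients(list(zip(clipped_grads, variables)))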

Vijay Mariappan
  • Thanks for this answer. I chose to remedy the issue by initializing the LSTM kernel with a very small value (`1.e-10`). Will have to see if this doesn't mess things up elsewhere... – KeithWM Jan 13 '19 at 23:20