
I know how to make softmax numerically stable by adding -max_i x_i to each element. This avoids overflow and underflow. However, taking the log of the result can still fail: softmax(x) can underflow to zero for some elements, and the log of zero is -infinity.
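A minimal numpy sketch of the failure described above. The input values are my own, chosen for illustration. The max-shifted softmax is computed without overflow, but two of the probabilities underflow to exactly 0.0 in float64, so taking their log gives -inf.

```python
import numpy as np

def stable_softmax(x):
    # Subtract the max so the largest exponent is exp(0) = 1 (no overflow)
    z = np.exp(x - np.max(x))
    return z / np.sum(z)

x = np.array([4.0, 5.0, 1000.0])  # illustrative values
p = stable_softmax(x)
print(p)  # exp(-996) and exp(-995) underflow, so the first two entries are exactly 0.0

# Silence numpy's "divide by zero encountered in log" warning for the demo
with np.errstate(divide="ignore"):
    naive_log = np.log(p)
print(naive_log)  # -inf for the first two entries
```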

I am not sure how to fix it. I know this is a common problem, and I read several answers on it, but I didn't understand them, so I am still confused about how to solve it.

PS: If you could provide a simple example, that would be awesome.

Abhishek Bhatia

4 Answers


To stabilize log-softmax, most implementations, such as TensorFlow and Theano, use a trick that factors out the largest component, max_i x_i. The same trick is commonly used to compute softmax stably. For log-softmax, with an arbitrary constant b, we begin with:

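$$\log \operatorname{softmax}(x)_i = \log \frac{e^{x_i}}{\sum_j e^{x_j}} = \log \frac{e^{b}\, e^{x_i - b}}{e^{b} \sum_j e^{x_j - b}}$$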

After extracting out the exp(b) and using the fact that log(exp(x)) = x, we have:

$$\log \operatorname{softmax}(x)_i = x_i - b - \log \sum_j e^{x_j - b}$$

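If we set b = max_i x_i, this equation is protected against both overflow and underflow: every exponent x_j - b is at most 0, so no term e^{x_j - b} can overflow, and the term where x_j = b equals e^0 = 1, so the sum is at least 1 and its log is finite instead of -inf.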
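A worked instance, using the illustrative values x = (4, 5, 1000), so that b = 1000:

$$\log \operatorname{softmax}(x)_1 = 4 - 1000 - \log\left(e^{-996} + e^{-995} + e^{0}\right) = -996 - \log\left(1 + e^{-995} + e^{-996}\right) \approx -996$$

The argument of the log is at least 1, so the result stays finite. Computing softmax first and then taking the log would instead round the probability for x_1 to 0 in floating point and return -inf.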


In terms of code, if x is a vector:

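import numpy as np

def log_softmax(x):
    # Shift by the max: exp() cannot overflow, and the sum is >= 1, so its log is finite
    x_off = x - np.max(x)
    return x_off - np.log(np.sum(np.exp(x_off)))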
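The function above assumes a 1-D vector. A minimal sketch of a batched variant, assuming each row of a 2-D array holds one set of logits; the `axis` parameter is my addition, not part of the original answer. `keepdims=True` keeps the reduced max and sum broadcastable against `x`.

```python
import numpy as np

def log_softmax(x, axis=-1):
    # axis is an added parameter (assumption): reduce along it, keep dims for broadcasting
    x_off = x - np.max(x, axis=axis, keepdims=True)
    # The largest term in each sum is exp(0) = 1, so every log below is finite
    return x_off - np.log(np.sum(np.exp(x_off), axis=axis, keepdims=True))

x = np.array([[4.0, 5.0, 1000.0],
              [1.0, 2.0, 3.0]])
ls = log_softmax(x)
print(ls)                      # all entries finite; row 0 starts at -996
print(np.exp(ls).sum(axis=1))  # each row's probabilities sum to ~1
```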

See also: https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/

Mateen Ulhaq
e3oroush
The TensorFlow documentation for tf.nn.log_softmax gives:

logsoftmax = logits - log(reduce_sum(exp(logits), dim))

Reference: https://www.tensorflow.org/api_docs/python/tf/nn/log_softmax
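Read literally, that formula exponentiates the raw logits. A numpy sketch (my own illustration, not TensorFlow's implementation) of why a stable implementation needs the max shift described in the first answer: with a logit of 1000, exp overflows float64 and every output becomes -inf, while the shifted version, which is equal in exact arithmetic, stays finite.

```python
import numpy as np

logits = np.array([4.0, 5.0, 1000.0])  # illustrative values

# Literal transcription of the formula: exp(1000.0) overflows to inf,
# log(inf) is inf, and logits - inf is -inf for every entry
with np.errstate(over="ignore"):
    literal = logits - np.log(np.sum(np.exp(logits)))
print(literal)

# Same formula after shifting by the max: finite
shifted = logits - np.max(logits)
stable = shifted - np.log(np.sum(np.exp(shifted)))
print(stable)
```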

Nemo

Just use tf.nn.softmax_cross_entropy_with_logits, which takes the raw logits and takes care of the NaN:

import tensorflow as tf

# Signature: tf.nn.softmax_cross_entropy_with_logits(labels, logits, axis=-1, name=None)
logits = tf.constant([[4, 5, 1000]], dtype=tf.float32)
labels = tf.constant([[1, 0, 1]], dtype=tf.float32)

# Case 1: the fused op works directly on the logits
output = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
print(output)
# tf.Tensor([996.], shape=(1,), dtype=float32)

# Case 2: softmax first, then log
a = tf.nn.softmax(logits)
output = tf.reduce_sum(-(labels * tf.math.log(a)))
print(output)
# tf.Tensor(nan, shape=(), dtype=float32)

# This happens because the softmax values for 4 and 5 underflow to 0 in float32,
# so tf.math.log(a) is -inf there, and the middle term 0 * -inf is nan
print(a)
# <tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[0., 0., 1.]], dtype=float32)>
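For comparison, a numpy sketch of the same loss computed the stable way, on the same inputs. This is my own illustration, not TensorFlow's code, and `softmax_cross_entropy` is a hypothetical helper, not a numpy or TensorFlow function. It computes log-softmax directly with the max shift and never forms log(softmax), so there is no log(0) and no 0 * -inf.

```python
import numpy as np

def log_softmax(x, axis=-1):
    # Max-shifted log-softmax: every entry is finite
    x_off = x - np.max(x, axis=axis, keepdims=True)
    return x_off - np.log(np.sum(np.exp(x_off), axis=axis, keepdims=True))

def softmax_cross_entropy(labels, logits, axis=-1):
    # Hypothetical helper: -sum(labels * log_softmax(logits)) along axis
    return -np.sum(labels * log_softmax(logits, axis=axis), axis=axis)

logits = np.array([[4.0, 5.0, 1000.0]])
labels = np.array([[1.0, 0.0, 1.0]])
print(softmax_cross_entropy(labels, logits))  # [996.]
```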
Josef

Mathematical tricks cannot make log 0 anything other than -inf. If you think it through, the only way out is to normalize the data so that you don't end up there.

prosti