Softmax Implementation in PyTorch and Numpy
The softmax function is defined as follows:

softmax(x_i) = exp(x_i) / sum_j(exp(x_j))
A direct implementation of the above formula is as follows:
import numpy as np

def softmax(x):
    return np.exp(x) / np.exp(x).sum(axis=0)
The above implementation can run into arithmetic overflow, because np.exp(x) produces inf for large values of x.
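For a quick illustration (values chosen arbitrarily, using the softmax defined above), large logits already break the direct implementation:

import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])  # arbitrary large logits, for illustration only
print(np.exp(x))   # overflows to [inf, inf, inf] (NumPy emits a RuntimeWarning)
print(softmax(x))  # inf / inf gives [nan, nan, nan]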
To avoid the overflow, we can multiply both the numerator and the denominator of the softmax equation by a constant C. The softmax function then becomes the following:

softmax(x_i) = C * exp(x_i) / (C * sum_j(exp(x_j))) = exp(x_i + log(C)) / sum_j(exp(x_j + log(C)))
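As a quick sanity check (again with arbitrary values and the direct softmax from above): multiplying the numerator and denominator by C is the same as adding log(C) to every x_j, so shifting x by a constant does not change the output:

import numpy as np

x = np.array([1.0, 2.0, 3.0])     # arbitrary small logits
shifted = x - np.max(x)           # equivalent to choosing log(C) = -max(x)
print(np.allclose(softmax(x), softmax(shifted)))  # True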
Below, we implement this approach in PyTorch, taking log(C) as -max(x) so that the largest exponent is 0 and the exponential can no longer overflow:
import torch

def softmax_torch(x):
    # Assumes x has at least 2 dimensions; softmax is computed over dim 1.
    # torch.max returns (values, indices), so [0] selects the values.
    maxes = torch.max(x, 1, keepdim=True)[0]
    x_exp = torch.exp(x - maxes)
    x_exp_sum = torch.sum(x_exp, 1, keepdim=True)
    probs = x_exp / x_exp_sum
    return probs
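As a quick check (illustrative values only), the stabilized softmax_torch stays finite on the same large logits that overflowed the direct implementation:

import torch

x = torch.tensor([[1000.0, 1001.0, 1002.0]])  # shape (1, 3); softmax is taken over dim 1
probs = softmax_torch(x)
print(probs)             # finite probabilities, no inf or nan
print(probs.sum(dim=1))  # sums to 1 along dim 1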
The corresponding NumPy implementation is as follows:
import numpy as np

def softmax_np(x):
    # Unlike torch.max, np.max returns the values directly, so no [0] indexing is needed.
    maxes = np.max(x, axis=1, keepdims=True)
    x_exp = np.exp(x - maxes)
    x_exp_sum = np.sum(x_exp, axis=1, keepdims=True)
    probs = x_exp / x_exp_sum
    return probs
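A short usage example (illustrative values) for the NumPy version, mixing a large-magnitude row with a small one:

import numpy as np

x = np.array([[1000.0, 1001.0, 1002.0],
              [1.0, 2.0, 3.0]])  # arbitrary values, for illustration only
probs = softmax_np(x)
print(probs.sum(axis=1))         # each row sums to 1, no overflow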
We can compare the results against PyTorch's built-in implementation, torch.nn.functional.softmax, using the snippet below:
import torch
import numpy as np

if __name__ == "__main__":
    x = torch.randn(1, 3, 5, 10)
    # dim=1 matches the dimension used in softmax_torch and softmax_np above
    std_pytorch_softmax = torch.nn.functional.softmax(x, dim=1)
    pytorch_impl = softmax_torch(x)
    numpy_impl = softmax_np(x.detach().cpu().numpy())
    print("Shapes: x --> {}, std --> {}, pytorch impl --> {}, numpy impl --> {}".format(
        x.shape, std_pytorch_softmax.shape, pytorch_impl.shape, numpy_impl.shape))
    print("Std and torch implementation are same?", torch.allclose(std_pytorch_softmax, pytorch_impl))
    print("Std and numpy implementation are same?", torch.allclose(std_pytorch_softmax, torch.from_numpy(numpy_impl)))
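If everything matches, both allclose checks should print True, confirming that the manual PyTorch and NumPy implementations agree with torch.nn.functional.softmax.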