When using F.softmax(), we pass dim=0 or dim=1. Intuitively, dim=0 should mean "row", but it seems to operate along the columns instead. Is this true?

>>> import torch
>>> import torch.nn.functional as F
>>> x = torch.tensor([[1,2],[3,4]],dtype=torch.float)
>>> F.softmax(x,dim=0)
tensor([[0.1192, 0.1192],
        [0.8808, 0.8808]])
>>> F.softmax(x,dim=1)
tensor([[0.2689, 0.7311],
        [0.2689, 0.7311]])

Here when dim=0, probabilities along the columns sum to 1. Similarly when dim=1 probabilities along the rows sum to 1. Can someone explain how dim is used in PyTorch?

Berriel
Dhruv Vashist
    Does this answer your question? [What does axis in pandas mean?](https://stackoverflow.com/questions/22149584/what-does-axis-in-pandas-mean) – kHarshit Sep 17 '21 at 12:04

1 Answer

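Indeed, in the 2D case, axis=0 indexes the rows and axis=1 indexes the columns. An operation applied along dim=0 therefore runs across the rows, i.e. it works within each column, and an operation along dim=1 works within each row.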

The dim option specifies the dimension along which the softmax is applied, i.e. summing the result over that same axis gives 1s:

>>> x = torch.arange(1, 7, dtype=float).reshape(2,3)
>>> x
tensor([[1., 2., 3.],
        [4., 5., 6.]], dtype=torch.float64)

On axis=0:

>>> F.softmax(x, dim=0).sum(0)
tensor([1.0000, 1.0000, 1.0000], dtype=torch.float64)

On axis=1:

>>> F.softmax(x, dim=1).sum(1)
tensor([1.0000, 1.0000], dtype=torch.float64)

This is the expected behavior for torch.nn.functional.softmax, as its documentation states:

[...] Parameters:

  • dim (int) – A dimension along which Softmax will be computed (so every slice along dim will sum to 1).
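
To make "every slice along dim will sum to 1" concrete, here is a plain-Python sketch of the same computation on a nested list. It is only an illustration of the semantics, not how PyTorch implements softmax; the helper name softmax_2d is made up for this example. The point is which index varies inside the normalising sum: for dim=0 the row index i varies while the column j is held fixed, and for dim=1 it is the other way around.

```python
import math

def softmax_2d(rows, dim):
    """Softmax over a 2D nested list along `dim` (hypothetical helper, not a PyTorch API)."""
    n_rows, n_cols = len(rows), len(rows[0])
    out = [[0.0] * n_cols for _ in range(n_rows)]
    if dim == 0:
        # Hold the column index j fixed and normalise over the row index i.
        for j in range(n_cols):
            exps = [math.exp(rows[i][j]) for i in range(n_rows)]
            total = sum(exps)
            for i in range(n_rows):
                out[i][j] = exps[i] / total
    elif dim == 1:
        # Hold the row index i fixed and normalise over the column index j.
        for i in range(n_rows):
            exps = [math.exp(rows[i][j]) for j in range(n_cols)]
            total = sum(exps)
            for j in range(n_cols):
                out[i][j] = exps[j] / total
    else:
        raise ValueError("dim must be 0 or 1 for a 2D input")
    return out

x = [[1.0, 2.0], [3.0, 4.0]]  # same values as the tensor in the question
d0 = softmax_2d(x, dim=0)
d1 = softmax_2d(x, dim=1)

print([[round(v, 4) for v in row] for row in d0])
# [[0.1192, 0.1192], [0.8808, 0.8808]]  -> each column sums to 1
print([[round(v, 4) for v in row] for row in d1])
# [[0.2689, 0.7311], [0.2689, 0.7311]]  -> each row sums to 1
```

The printed values match the F.softmax outputs in the question, which is why dim=0 makes the columns sum to 1: the sum runs over the row index.
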
Ivan