Although Berriel's solution answers this specific question, I thought adding some explanation might shed light on the trick employed here, so that it can be adapted to (m)any other dimensions.
Let's start by inspecting the shape of the input tensor x:
In [58]: x.shape
Out[58]: torch.Size([3, 2, 2])
So, we have a 3D tensor of shape (3, 2, 2). Now, as per the OP's question, we need to compute the maximum of the values in the tensor along both the 1st and 2nd dimensions. As of this writing, the dim argument of torch.max() supports only an int, so we can't pass a tuple. Hence, we will use the following trick, which I will call:
The Flatten & Max Trick: since we want to compute the max over both the 1st and 2nd dimensions, we will flatten both of these dimensions into a single dimension and leave the 0th dimension untouched. This is exactly what happens when we do:
In [61]: x.flatten().reshape(x.shape[0], -1).shape
Out[61]: torch.Size([3, 4]) # 2*2 = 4
So now we have shrunk the 3D tensor to a 2D tensor (i.e., a matrix).
In [62]: x.flatten().reshape(x.shape[0], -1)
Out[62]:
tensor([[-0.3000, -0.2926, -0.2705, -0.2632],
[-0.1821, -0.1747, -0.1526, -0.1453],
[-0.0642, -0.0568, -0.0347, -0.0274]])
Now, we can simply apply max over the 1st dimension (which, in this case, is also the last dimension), since the flattened dimensions reside there.
In [65]: x.flatten().reshape(x.shape[0], -1).max(dim=1) # or: `dim = -1`
Out[65]:
torch.return_types.max(
values=tensor([-0.2632, -0.1453, -0.0274]),
indices=tensor([3, 3, 3]))
We got 3 values in the resultant tensor since we had 3 rows in the matrix.
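To make this reproducible end to end, here is the same trick on a small self-contained tensor (the values below are arbitrary and differ from the tensor above), along with a sanity check against reducing one dimension at a time:

```python
import torch

# Self-contained example; values are arbitrary, not the OP's tensor.
x = torch.arange(12, dtype=torch.float32).reshape(3, 2, 2)

# Flatten the 1st and 2nd dimensions into one, keep the 0th, then reduce.
flat = x.flatten().reshape(x.shape[0], -1)   # shape: (3, 4)
values, indices = flat.max(dim=1)            # one max per slice along dim 0

# Sanity check: same result as reducing one dimension at a time.
assert torch.equal(values, x.max(dim=2).values.max(dim=1).values)
print(values)    # tensor([ 3.,  7., 11.])
```

Note that x.reshape(x.shape[0], -1) alone would do; the leading flatten() is kept only to match the snippets above.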
Now, on the other hand, if you want to compute the max over the 0th and 1st dimensions, you'd do:
In [80]: x.flatten().reshape(-1, x.shape[-1]).shape
Out[80]: torch.Size([6, 2]) # 3*2 = 6
In [79]: x.flatten().reshape(-1, x.shape[-1])
Out[79]:
tensor([[-0.3000, -0.2926],
[-0.2705, -0.2632],
[-0.1821, -0.1747],
[-0.1526, -0.1453],
[-0.0642, -0.0568],
[-0.0347, -0.0274]])
Now, we can simply apply max over the 0th dimension since that is where our flattening put the merged dimensions. (Also, starting from our original shape of (3, 2, 2), after taking the max over the first two dimensions, we should get two values as the result.)
In [82]: x.flatten().reshape(-1, x.shape[-1]).max(dim=0)
Out[82]:
torch.return_types.max(
values=tensor([-0.0347, -0.0274]),
indices=tensor([5, 5]))
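The same sanity check works for the 0th-and-1st-dimension case, again on an arbitrary self-contained tensor:

```python
import torch

# Arbitrary example tensor standing in for x.
x = torch.arange(12, dtype=torch.float32).reshape(3, 2, 2)

# Flatten the 0th and 1st dimensions together, keep the last, reduce over dim 0.
flat = x.flatten().reshape(-1, x.shape[-1])  # shape: (6, 2)
values, indices = flat.max(dim=0)

# Same result as reducing the 0th dimension twice in a row.
assert torch.equal(values, x.max(dim=0).values.max(dim=0).values)
print(values)    # tensor([10., 11.])
```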
In a similar vein, you can adapt this approach to multiple dimensions and other reduction functions such as min.
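For instance, the whole recipe can be wrapped in a small helper (flat_reduce is a hypothetical name, not a PyTorch builtin) that handles an arbitrary set of dimensions by permuting them to the end before flattening:

```python
import torch

def flat_reduce(x, dims, op="max"):
    """Hypothetical helper: reduce over several dimensions at once by
    permuting them to the end and flattening them into a single dimension.
    `op` is the name of a reducing Tensor method, e.g. "max" or "min"."""
    dims = [d % x.dim() for d in dims]                 # normalize negative dims
    keep = [d for d in range(x.dim()) if d not in dims]
    # permute() may return a non-contiguous view; reshape() copies if needed.
    flat = x.permute(*keep, *dims).reshape(*[x.shape[d] for d in keep], -1)
    return getattr(flat, op)(dim=-1)                   # (values, indices)

x = torch.arange(12, dtype=torch.float32).reshape(3, 2, 2)
vals, idx = flat_reduce(x, dims=(1, 2), op="max")
print(vals)    # tensor([ 3.,  7., 11.])
vals, idx = flat_reduce(x, dims=(0, 1), op="min")
print(vals)    # tensor([0., 1.])
```

The permute step is what frees this from the limitation of the snippets above, which only work when the reduced dimensions are already contiguous at the front or back of the shape.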
Note: I'm using 0-based dimensions (0, 1, 2, 3, ...) just to be consistent with PyTorch usage and the code.