I do not understand how to evaluate this expression: `x.view(*(x.shape[:-2]),-1).mean(-1)`, if `x.shape == (N, C, H, W)`. What does the asterisk `*` stand for? And what is `mean(-1)`?
What is `*`?
For `.view()`, PyTorch expects the new shape to be given as individual `int` arguments (shown in the docs as `*shape`). In Python, the asterisk (`*`) unpacks a sequence such as a list or tuple into its individual elements, so `view` receives its arguments in the form it expects.
So, in your case, `x.shape` is `(N, C, H, W)`. If you were to pass `x.shape[:-2]` without the asterisk, you would get `x.view((N, C), -1)`, which is not what `view()` expects. Unpacking `(N, C)` with the asterisk results in the call `view(N, C, -1)`, which is the form it expects. The resulting shape is `(N, C, H*W)` (a 3D tensor instead of a 4D one).
What is `mean(-1)`?
Simply look at the documentation of `.mean()`: the first argument is `dim`. That is, `x.mean(-1)` takes the mean along the last dimension. In your case, since `keepdim=False` by default, the output is an `(N, C)`-sized tensor in which each element corresponds to the mean over both spatial dimensions.
The whole expression is equivalent to `x.mean(-1).mean(-1)`, up to floating-point rounding.
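The equivalence can be checked numerically. The sketch below uses random data with assumed sizes, and also compares against `x.mean(dim=(-2, -1))`, which averages over both spatial dimensions in a single call. The comparison uses a small tolerance because the three variants may sum the values in a different order.

```python
import torch

torch.manual_seed(0)
# Example sizes, chosen only for illustration.
N, C, H, W = 2, 3, 4, 5
x = torch.randn(N, C, H, W)

a = x.view(*x.shape[:-2], -1).mean(-1)  # flatten H and W, then average
b = x.mean(-1).mean(-1)                 # average over W, then over H
c = x.mean(dim=(-2, -1))                # average over both dims at once

print(a.shape)                          # torch.Size([2, 3])
print(torch.allclose(a, b, atol=1e-6))
print(torch.allclose(a, c, atol=1e-6))
```

Averaging the row means gives the overall mean here only because every row has the same number of elements (`W`).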