Both .flatten() and .view(-1) flatten a tensor in PyTorch. What's the difference?

- Does .flatten() copy the data of the tensor?
- Is .view(-1) faster?
- Is there any situation where .flatten() doesn't work?
In addition to @adeelh's comment, there is another difference: torch.flatten() results in a .reshape(), and the differences between .reshape() and .view() are:

> [...] torch.reshape may return a copy or a view of the original tensor. You cannot count on it to return a view or a copy.
>
> Another difference is that reshape() can operate on both contiguous and non-contiguous tensors, while view() can only operate on contiguous tensors. Also see here about the meaning of contiguous.
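To make the contiguity point concrete, here is a minimal sketch (the tensor t is arbitrary; only the calls to view, reshape and flatten are the point):

```python
import torch

t = torch.arange(6).reshape(2, 3).t()  # transpose -> non-contiguous tensor
print(t.is_contiguous())               # False

print(t.reshape(-1))                   # works (may return a copy)
print(t.flatten())                     # works (goes through reshape)

try:
    t.view(-1)                         # view() requires contiguous memory
except RuntimeError as e:
    print("view failed:", e)
```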
For context:

The community had requested a flatten function for a while, and after Issue #7743, the feature was implemented in PR #8578. You can see the implementation of flatten here, where a call to .reshape() can be seen in the return line.
flatten is simply a convenient alias of a common use-case of view.¹

There are several others:

Function | Equivalent view logic
---|---
flatten() | view(-1)
flatten(start, end) | view(*t.shape[:start], -1, *t.shape[end+1:])
squeeze() | view(*[s for s in t.shape if s != 1])
unsqueeze(i) | view(*t.shape[:i], 1, *t.shape[i:])
Note that flatten allows you to flatten a specific, contiguous subset of dimensions, with the start_dim and end_dim arguments.

¹ Though it uses reshape under the hood.
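As a quick illustration of the start_dim/end_dim point above, a small sketch (the shape is arbitrary, chosen only for the example):

```python
import torch

t = torch.zeros(2, 3, 4, 5)

# flatten dims 1..2 ...
a = t.flatten(1, 2)
# ... matches the view recipe from the table with start=1, end=2
b = t.view(*t.shape[:1], -1, *t.shape[3:])

print(a.shape, b.shape)  # torch.Size([2, 12, 5]) torch.Size([2, 12, 5])
```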
First of all, .view() works only on contiguous data, while .flatten() works on both contiguous and non-contiguous data. Functions like transpose, which produce non-contiguous data, can be acted on by .flatten() but not by .view().
Coming to the copying of data: neither .view() nor .flatten() copies data when working on contiguous data. However, in the case of non-contiguous data, .flatten() first copies the data into contiguous memory and then changes the dimensions. Any change in the new tensor would not affect the original tensor.
```python
ten = torch.zeros(2, 3)
ten_view = ten.view(-1)
ten_view[0] = 123
ten
# tensor([[123.,   0.,   0.],
#         [  0.,   0.,   0.]])
```

```python
ten = torch.zeros(2, 3)
ten_flat = ten.flatten()
ten_flat[0] = 123
ten
# tensor([[123.,   0.,   0.],
#         [  0.,   0.,   0.]])
```
In the above code, the tensor ten has contiguous memory allocation. Any changes to ten_view or ten_flat are reflected in the tensor ten.
```python
ten = torch.zeros(2, 3).transpose(0, 1)
ten_flat = ten.flatten()
ten_flat[0] = 123
ten
# tensor([[0., 0.],
#         [0., 0.],
#         [0., 0.]])
```
In this case, the non-contiguous transposed tensor ten is used for flatten(). Any changes made to ten_flat are not reflected in ten.
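If you want to check directly whether a copy was made, rather than inferring it from mutation, comparing data_ptr() is one option; a minimal sketch (variable names are just illustrative):

```python
import torch

contig = torch.zeros(2, 3)
non_contig = torch.zeros(2, 3).transpose(0, 1)

# Contiguous input: both results share storage with the original (no copy)
print(contig.view(-1).data_ptr() == contig.data_ptr())            # True
print(contig.flatten().data_ptr() == contig.data_ptr())           # True

# Non-contiguous input: flatten() copies into new, contiguous storage
print(non_contig.flatten().data_ptr() == non_contig.data_ptr())   # False
```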