
Both .flatten() and .view(-1) flatten a tensor in PyTorch. What's the difference?

  1. Does .flatten() copy the data of the tensor?
  2. Is .view(-1) faster?
  3. Is there any situation where .flatten() doesn't work?
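
For concreteness, a minimal example where both calls seem to produce the same result:

    import torch

    t = torch.arange(6).reshape(2, 3)
    print(t.flatten())   # tensor([0, 1, 2, 3, 4, 5])
    print(t.view(-1))    # tensor([0, 1, 2, 3, 4, 5])
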
ipid
    I think they are identical for the default arguments to `.flatten()`, but `.flatten()` allows you to pass a `start_dim` and an `end_dim` to get more complex behavior. For example, `torch.ones(10, 4, 5, 6).flatten(start_dim=1, end_dim=2)` returns a tensor of shape `(10, 20, 6)`. – adeelh Jul 27 '19 at 20:37

3 Answers


In addition to @adeelh's comment, there is another difference: torch.flatten() results in a call to .reshape(), and the differences between .reshape() and .view() are:

  • [...] torch.reshape may return a copy or a view of the original tensor. You cannot count on it to return a view or a copy.

  • Another difference is that reshape() can operate on both contiguous and non-contiguous tensors, while view() can only operate on contiguous tensors. Also see here about the meaning of contiguous.
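
A minimal sketch of that contiguity difference, using a transposed tensor as the non-contiguous example (view() raises, reshape() silently falls back to a copy):

    import torch

    t = torch.arange(6).reshape(2, 3)
    nc = t.t()                  # transpose -> non-contiguous

    print(nc.reshape(-1))       # works; has to return a copy here, since no view is possible
    try:
        nc.view(-1)             # fails: view() needs a compatible (contiguous) memory layout
    except RuntimeError as e:
        print(e)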

For context:

  • The community had requested a flatten function for a while, and after Issue #7743, the feature was implemented in PR #8578.

  • You can see the implementation of flatten here, where a call to .reshape() can be seen in the return line.
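
You can also observe the reshape-like behaviour from the outside: on a contiguous tensor flatten() returns a view (shared storage), while on a non-contiguous one it falls back to a copy. A quick check, comparing data pointers:

    import torch

    t = torch.arange(6).reshape(2, 3)
    print(t.flatten().data_ptr() == t.data_ptr())      # True: same storage, i.e. a view
    print(t.t().flatten().data_ptr() == t.data_ptr())  # False: the non-contiguous input was copied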

Berriel

flatten is simply a convenient alias of a common use-case of view.1

There are several others:

| Function | Equivalent `view` logic |
|---|---|
| `flatten()` | `view(-1)` |
| `flatten(start, end)` | `view(*t.shape[:start], -1, *t.shape[end+1:])` |
| `squeeze()` | `view(*[s for s in t.shape if s != 1])` |
| `unsqueeze(i)` | `view(*t.shape[:i], 1, *t.shape[i:])` |

Note that flatten allows you to flatten a specific contiguous subset of dimensions, with the start_dim and end_dim arguments.
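
For example, flatten(start_dim=1, end_dim=2) on a 4-D tensor matches the view logic from the table above (a small sanity check):

    import torch

    t = torch.ones(10, 4, 5, 6)
    a = t.flatten(start_dim=1, end_dim=2)        # shape (10, 20, 6)
    b = t.view(*t.shape[:1], -1, *t.shape[3:])   # the equivalent view call
    print(a.shape, torch.equal(a, b))            # torch.Size([10, 20, 6]) True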


  1. Actually, under the hood it uses the superficially equivalent reshape.
iacob

First of all, .view() works only on contiguous data, while .flatten() works on both contiguous and non-contiguous data. Functions like transpose, which produce non-contiguous data, can be followed by .flatten() but not by .view().

Coming to copying of data: neither .view() nor .flatten() copies data when it operates on contiguous data. However, in the case of non-contiguous data, .flatten() first copies the data into contiguous memory and then changes the dimensions, so any change in the new tensor will not affect the original tensor.

    import torch

    ten = torch.zeros(2, 3)   # contiguous tensor
    ten_view = ten.view(-1)
    ten_view[0] = 123
    ten

>>tensor([[123.,   0.,   0.],
           [  0.,   0.,   0.]])

    ten = torch.zeros(2, 3)
    ten_flat = ten.flatten()
    ten_flat[0] = 123
    ten

>>tensor([[123.,   0.,   0.],
        [  0.,   0.,   0.]])

In the above code, the tensor ten has a contiguous memory allocation, so changes to ten_view or ten_flat are reflected in the tensor ten.

    ten = torch.zeros(2, 3).transpose(0, 1)   # transpose -> non-contiguous
    ten_flat = ten.flatten()
    ten_flat[0] = 123
    ten

>>tensor([[0., 0.],
        [0., 0.],
        [0., 0.]])

In this case, the non-contiguous transposed tensor ten is passed to flatten(). Changes made to ten_flat are not reflected in ten.
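
If you want to verify that flatten() really copied the data here, checking contiguity and comparing data pointers is one quick way:

    import torch

    ten = torch.zeros(2, 3).transpose(0, 1)
    ten_flat = ten.flatten()
    print(ten.is_contiguous(), ten_flat.is_contiguous())  # False True
    print(ten_flat.data_ptr() == ten.data_ptr())          # False: the data was copied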

TanD