
I am working on an attention model, and before running the final model I was checking the tensor shapes that flow through the code. At one point I need to reshape a tensor. The tensor has shape `torch.Size([30, 8, 9, 64])`, where 30 is the batch size, 8 is the number of attention heads (not relevant to my question), 9 is the number of words in the sentence, and 64 is an intermediate embedding representation of each word. I have to reshape this to `torch.Size([30, 9, 512])` before processing it further. A reference implementation I found online does `x.transpose(1, 2).contiguous().view(30, -1, 512)`, whereas I was thinking that `x.transpose(1, 2).reshape(30, -1, 512)` should also work.

In the first case the `grad_fn` is `<ViewBackward>`, whereas in my case it is `<UnsafeViewBackward>`. Aren't these two the same operation? Will this result in a training error?
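Here is a minimal repro of what I am seeing (the exact `grad_fn` names may differ slightly across PyTorch versions, e.g. `ViewBackward0`):

```python
import torch

# (batch, heads, words, head_dim) as above; note 8 * 64 = 512
x = torch.randn(30, 8, 9, 64, requires_grad=True)

a = x.transpose(1, 2).contiguous().view(30, -1, 512)  # reference version
b = x.transpose(1, 2).reshape(30, -1, 512)            # my version

print(a.shape, torch.equal(a, b))  # torch.Size([30, 9, 512]) True
print(a.grad_fn)                   # <ViewBackward ...>
print(b.grad_fn)                   # <UnsafeViewBackward ...>
```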

abkds
The two operations are different, which is why you get different `grad_fn`s. Visit [here](https://stackoverflow.com/questions/49643225/whats-the-difference-between-reshape-and-view-in-pytorch) for more info. – David Ng Apr 27 '19 at 18:51

1 Answer


Aren't these two the same operation?

No. While they produce effectively the same tensor, the operations are not the same, and they are not guaranteed to have the same storage.
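You can check the storage divergence directly (a quick sketch; `data_ptr()` reports the address of a tensor's underlying buffer):

```python
import torch

x = torch.randn(30, 8, 9, 64)
t = x.transpose(1, 2)                 # non-contiguous view of x

a = t.contiguous().view(30, -1, 512)  # contiguous() copies; view() aliases that copy
b = t.reshape(30, -1, 512)            # t cannot be viewed in this shape, so reshape() also copies

print(t.data_ptr() == x.data_ptr())   # True  -- transpose is a view of x
print(a.data_ptr() == x.data_ptr())   # False -- a lives in its own storage
print(a.data_ptr() == b.data_ptr())   # False -- and so does b
```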

From the comment on `_unsafe_view()` in PyTorch's `TensorShape.cpp`:

```cpp
// _unsafe_view() differs from view() in that the returned tensor isn't treated
// as a view for the purposes of automatic differentiation. (It's not listed in
// VIEW_FUNCTIONS in gen_autograd.py).  It's only safe to use if the `self` tensor
// is temporary. For example, the viewed tensor here (a + b) is discarded immediately
// after viewing:
//
//  res = at::_unsafe_view(a + b, size);
//
// This is a hack because in-place operations on tensors treated like views
// can be much more expensive than the same operations on non-view tensors.
```

Note that `_unsafe_view()` can produce an error when applied to complex inputs, but complex tensors are not yet fully supported in PyTorch generally, so this is not unique to this function.
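In your case the viewed tensor is exactly such a temporary: the output of `transpose(1, 2)` is discarded immediately after reshaping. A quick sanity check (a sketch, assuming the shapes from the question) shows both versions backpropagate without error and produce identical gradients:

```python
import torch

x = torch.randn(30, 8, 9, 64, requires_grad=True)

x.transpose(1, 2).contiguous().view(30, -1, 512).sum().backward()
grad_view = x.grad.clone()

x.grad = None  # reset before the second pass
x.transpose(1, 2).reshape(30, -1, 512).sum().backward()

print(torch.equal(grad_view, x.grad))  # True -- no training error either way
```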

iacob