This is to be expected. Firstly, the forward pass is also a lot slower: with your for loop, Python dispatches the following requests to PyTorch `batch_size` times:
- fetch the `i`-th element of `x`
- add 1
- update the `i`-th element of `x` with the incremented value
Python is slow. In the second version, Python dispatches a single message, "add 1 everywhere", to PyTorch, whose backend is much faster than Python (not to mention the GPU acceleration it is capable of). This technique is called vectorization; it is not specific to PyTorch but applies to essentially all Python (and many other) math packages.
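A minimal sketch of the contrast (the names `batch_size` and `x` follow your snippet; the sizes and timings here are illustrative and will depend on your machine):

```python
import time

import torch

batch_size = 10_000
x = torch.zeros(batch_size)

# Version one: Python loop, batch_size separate dispatches to PyTorch
start = time.perf_counter()
for i in range(batch_size):
    x[i] = x[i] + 1
print(f"loop:       {time.perf_counter() - start:.4f} s")

# Version two: a single dispatch, "add 1 everywhere"
x = torch.zeros(batch_size)
start = time.perf_counter()
x = x + 1
print(f"vectorized: {time.perf_counter() - start:.4f} s")
```

On a typical machine the loop version is slower by orders of magnitude, and the gap only widens on a GPU.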
Secondly, for your backward pass, PyTorch needs to keep track of all operations applied to `x` and backpropagate through them. In the first case there are `batch_size` of them; in the second, just one. Again, vectorization wins.
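As a rough illustration of what autograd has to record in each case (a sketch, not your exact code; the `clone()` is there only because writing in place into a leaf tensor that requires gradients is not allowed):

```python
import torch

batch_size = 4

# Loop version: autograd records one indexed update per iteration
x = torch.zeros(batch_size, requires_grad=True)
y = x.clone()
for i in range(batch_size):
    y[i] = y[i] + 1
y.sum().backward()   # backpropagates through batch_size recorded ops
print(x.grad)        # tensor([1., 1., 1., 1.])

# Vectorized version: autograd records a single addition
x = torch.zeros(batch_size, requires_grad=True)
y = x + 1
y.sum().backward()   # backpropagates through one recorded op
print(x.grad)        # tensor([1., 1., 1., 1.])
```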