
In the cs231n 2017 class, when we backpropagate the gradient, we compute the gradient with respect to the biases like this:

db = np.sum(dscores, axis=0, keepdims=True)

What's the basic idea behind the sum operation? Thanks

Maxim
Woody. Wang
  • Actually, I'm afraid I didn't make it clear. What I wanna know is: why do we need to sum the rows rather than calculate the average or something else? – Woody. Wang Nov 09 '17 at 04:17

2 Answers


This is the formula for the derivative (more precisely, the gradient) of the loss function with respect to the bias (see this question and this post for derivation details).
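A short sketch of that derivation, assuming the usual fully connected layer: X is the N×D input batch, W is the D×C weight matrix, and b is a 1×C bias row that is broadcast to every row of the batch (the symbols are chosen here for illustration). The scores and their dependence on the bias are:

$$
s_{ij} = \sum_{k=1}^{D} x_{ik} w_{kj} + b_j,
\qquad
\frac{\partial s_{ij}}{\partial b_{j'}} = \begin{cases} 1 & j = j' \\ 0 & j \neq j' \end{cases}
$$

Applying the chain rule over every score that b_j touches:

$$
\frac{\partial L}{\partial b_j}
= \sum_{i=1}^{N} \sum_{j'=1}^{C} \frac{\partial L}{\partial s_{ij'}} \, \frac{\partial s_{ij'}}{\partial b_j}
= \sum_{i=1}^{N} \frac{\partial L}{\partial s_{ij}}
$$

Writing dscores[i, j] for ∂L/∂s_{ij}, the right-hand side is the sum of column j over all N examples, i.e. np.sum(dscores, axis=0).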

The numpy.sum call computes per-column sums along axis 0, i.e. it adds the rows together. Example:

import numpy as np
dscores = np.array([[1, 2, 3], [2, 3, 4]])   # a 2D matrix: 2 examples x 3 classes
db = np.sum(dscores, axis=0, keepdims=True)  # result: [[3 5 7]]

The result is exactly the element-wise sum [1, 2, 3] + [2, 3, 4] = [3, 5, 7]. In addition, keepdims=True preserves the number of dimensions of the original matrix, which is why the result is [[3 5 7]] (shape (1, 3)) instead of just [3 5 7] (shape (3,)).

By the way, if we were to compute np.sum(dscores, axis=1, keepdims=True), the result would be [[6] [9]].
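A quick sketch that reuses the matrix above and prints the shapes, so the effect of axis and keepdims is visible side by side. The variable names col_sums, col_sums_flat and row_sums are only for illustration. The (1, 3) result from keepdims=True has the same shape as a bias stored as a 1×C row vector, which is the layout assumed here.

```python
import numpy as np

dscores = np.array([[1, 2, 3], [2, 3, 4]])  # 2 examples (rows) x 3 classes (columns)

col_sums = np.sum(dscores, axis=0, keepdims=True)  # sum over examples, stays 2D
col_sums_flat = np.sum(dscores, axis=0)            # same sums, summed axis dropped
row_sums = np.sum(dscores, axis=1, keepdims=True)  # sum over classes instead

print(col_sums, col_sums.shape)            # [[3 5 7]] (1, 3)
print(col_sums_flat, col_sums_flat.shape)  # [3 5 7] (3,)
print(row_sums.tolist(), row_sums.shape)   # [[6], [9]] (2, 1)
```

Only axis=0 sums over the examples, which is what the bias gradient needs; axis=1 mixes classes together instead.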

[Update]

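Apparently, the focus of this question is the formula itself. I'd rather not go too far off-topic here, so I'll just outline the main idea. The sum appears in the formula because the bias is broadcast over the mini-batch in the forward pass: the same bias vector is added to the scores of every example. If you take just one example at a time, the bias derivative is just the error signal, i.e. that example's row of dscores (the links above explain this in detail). For a batch of examples, every example contributes its own gradient to the shared bias, and by the chain rule these contributions add up. That's why we take the sum along the batch axis, axis=0. As for averaging: in the cs231n code the loss is already averaged over the batch, and dscores is divided by the number of examples before this line, so taking the mean here would divide by N a second time.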
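To check the sum numerically, here is a minimal sketch on a hypothetical toy setup: a linear layer with a broadcast bias and a batch-averaged squared-error loss, chosen only because its gradient is easy to write by hand (it is not the softmax loss from the course). It compares np.sum(dscores, axis=0, keepdims=True) against a central finite-difference estimate, and shows that np.mean over the batch comes out N times too small, because the 1/N of the averaged loss is already inside dscores.

```python
import numpy as np

# Hypothetical toy problem: scores = X @ W + b, with b of shape (1, C)
# broadcast over all N rows of the batch.
rng = np.random.default_rng(0)
N, D, C = 4, 3, 5
X = rng.standard_normal((N, D))
W = rng.standard_normal((D, C))
b = rng.standard_normal((1, C))
T = rng.standard_normal((N, C))  # arbitrary targets for the toy loss


def loss_fn(bias):
    """Batch mean of 0.5 * ||scores_i - T_i||^2."""
    scores = X @ W + bias  # bias is broadcast to shape (N, C)
    return 0.5 * np.sum((scores - T) ** 2) / N


# Analytic dL/dscores; the 1/N from the batch mean is already included here.
dscores = (X @ W + b - T) / N
db = np.sum(dscores, axis=0, keepdims=True)         # sum over the batch axis
db_mean = np.mean(dscores, axis=0, keepdims=True)   # divides by N a second time

# Central finite differences, one bias entry at a time.
eps = 1e-6
db_num = np.zeros_like(b)
for j in range(C):
    b_plus, b_minus = b.copy(), b.copy()
    b_plus[0, j] += eps
    b_minus[0, j] -= eps
    db_num[0, j] = (loss_fn(b_plus) - loss_fn(b_minus)) / (2 * eps)

print(np.max(np.abs(db - db_num)))       # tiny: the summed gradient matches
print(np.max(np.abs(db_mean - db_num)))  # clearly nonzero: the mean is N times too small
```

The same comparison works for any loss; only the line computing dscores would change.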

Maxim
  • Thank you very much. But I guess I didn't explain my question clearly enough. What I wanna know is: why is the bias gradient the sum of the rows? – Woody. Wang Nov 09 '17 at 04:11
  • @Woody.Wang Ah, got you. I've included the main idea behind this formula in the answer, but there's no better way to understand it than to derive it yourself, so I'd strongly recommend doing so. – Maxim Nov 09 '17 at 09:09

Numpy axis visual description:

[Image: Numpy axis]

Debashis Sahoo