
Q1.

I'm trying to write a custom autograd function in PyTorch, but I'm having trouble with the analytical backpropagation for y = x / sum(x, dim=0), where x is a 2-dimensional tensor of size (Height, Width).

Here's my code:

import torch

class MyFunc(torch.autograd.Function):
  @staticmethod
  def forward(ctx, input):
    ctx.save_for_backward(input)
    input = input / torch.sum(input, dim=0)

    return input

  @staticmethod
  def backward(ctx, grad_output):
    input = ctx.saved_tensors[0]
    H, W = input.size()
    sum = torch.sum(input, dim=0)
    grad_input = grad_output * (1/sum - input*1/sum**2)

    return grad_input

I used gradcheck (from torch.autograd) to compare my analytical Jacobian with the numerical one:

from torch.autograd import gradcheck
func = MyFunc.apply
input = (torch.randn(3,3,dtype=torch.double,requires_grad=True))
test = gradcheck(func, input)

and the result was:

[screenshot of the gradcheck output]

Could someone please help me get the correct backpropagation result?

Thanks!


Q2.

Thanks for the answers!

Thanks to your help, I was able to implement backpropagation for an (H,W) tensor.

However, when implementing backpropagation for an (N,H,W) tensor, I ran into a problem. I think it comes from how I initialize the new tensor.

Here's my new code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyFunc(torch.autograd.Function):
  @staticmethod
  def forward(ctx, input):
    ctx.save_for_backward(input)
    
    N = input.size(0)
    for n in range(N):
      input[n] /= torch.sum(input[n], dim=0)

    return input

  @staticmethod
  def backward(ctx, grad_output):
    input = ctx.saved_tensors[0]
    N, H, W = input.size()
    I = torch.eye(H).unsqueeze(-1)
    sum = input.sum(1)

    grad_input = torch.zeros((N,H,W), dtype = torch.double, requires_grad=True)
    for n in range(N):
      grad_input[n] = ((sum[n] * I - input[n]) * grad_output[n] / sum[n]**2).sum(1)

    return grad_input

The gradcheck code is:

from torch.autograd import gradcheck
func = MyFunc.apply
input = (torch.rand(2,2,2,dtype=torch.double,requires_grad=True))
test = gradcheck(func, input)
print(test)

and the result is:

[screenshot of the gradcheck output]

I don't know why the error occurs...

Your help would be very useful for implementing my own convolutional network.

Thanks! Have a nice day.

Jin

3 Answers


Let's look at an example with a single column, for instance [[x1], [x2], [x3]].

Let sum be x1 + x2 + x3; normalizing x then gives y = [[y1], [y2], [y3]] = [[x1/sum], [x2/sum], [x3/sum]]. You're looking for dL/dx1, dL/dx2, and dL/dx3, which we'll simply write as dx1, dx2, and dx3. Likewise, we write dyi for dL/dyi.

So dx1 is equal to dL/dy1*dy1/dx1 + dL/dy2*dy2/dx1 + dL/dy3*dy3/dx1, because x1 contributes to every output element in its column: y1, y2, and y3.

We have:

  • dy1/dx1 = d(x1/sum)/dx1 = (sum - x1)/sum²

  • dy2/dx1 = d(x2/sum)/dx1 = -x2/sum²

  • similarly, dy3/dx1 = d(x3/sum)/dx1 = -x3/sum²

Therefore dx1 = (sum - x1)/sum²*dy1 - x2/sum²*dy2 - x3/sum²*dy3, and likewise for dx2 and dx3. As a result, the Jacobian entries are dyi/dxi = (sum - xi)/sum² on the diagonal and dyj/dxi = -xj/sum² for every j different from i.
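Written per column, with s the column sum and δij the Kronecker delta, the same result takes the compact form below; the second line is the vector-Jacobian product that backward has to return, and it never requires building J explicitly:

```latex
s = \sum_k x_k, \qquad y_j = \frac{x_j}{s}, \qquad
\frac{\partial y_j}{\partial x_i} = \frac{\delta_{ij}\, s - x_j}{s^2}

\frac{\partial L}{\partial x_i}
  = \sum_j \frac{\partial L}{\partial y_j}\,\frac{\delta_{ij}\, s - x_j}{s^2}
  = \frac{1}{s}\Big(\frac{\partial L}{\partial y_i} - \sum_j y_j\,\frac{\partial L}{\partial y_j}\Big)
```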

Your implementation only keeps the diagonal terms and misses all the off-diagonal components.
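To make the missing terms concrete, here is a small sketch comparing the question's diagonal-only formula with the full per-column vector-Jacobian product, dL/dxi = (dyi - Σj yj·dyj)/sum, and with the gradient autograd itself computes. The helper names diagonal_only_grad and full_grad are hypothetical, and the random positive input is an assumption made only to keep the column sums away from zero.

```python
import torch

# Hypothetical helper: the question's backward, which keeps only the
# diagonal Jacobian term (sum - xi) / sum**2
def diagonal_only_grad(x, dy):
    s = x.sum(dim=0)
    return dy * (1 / s - x / s**2)

# Hypothetical helper: full vector-Jacobian product, per column
# dL/dxi = (dyi - sum_j yj * dyj) / s
def full_grad(x, dy):
    s = x.sum(dim=0)
    y = x / s
    return (dy - (y * dy).sum(dim=0, keepdim=True)) / s

torch.manual_seed(0)
x = torch.rand(3, 2, dtype=torch.double) + 0.5  # positive, so column sums are nonzero
dy = torch.randn(3, 2, dtype=torch.double)      # arbitrary upstream gradient

# Reference gradient computed by autograd
xr = x.clone().requires_grad_(True)
(xr / xr.sum(dim=0)).backward(dy)

print(torch.allclose(full_grad(x, dy), xr.grad))           # True
print(torch.allclose(diagonal_only_grad(x, dy), xr.grad))  # False: cross terms missing
```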

Keeping the same one-column example, with x1=2, x2=3, and x3=5:

>>> x = torch.tensor([[2.], [3.], [5.]])

>>> sum = x.sum(0)
>>> sum
tensor([10.])

The Jacobian will be:

>>> J = (sum*torch.eye(x.size(0)) - x)/sum**2
>>> J
tensor([[ 0.0800, -0.0200, -0.0200],
        [-0.0300,  0.0700, -0.0300],
        [-0.0500, -0.0500,  0.0500]])
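One way to double-check this closed form is to compare it with the Jacobian autograd computes. A minimal sketch, assuming a PyTorch version recent enough to provide torch.autograd.functional.jacobian; sum is renamed to s here only to avoid shadowing Python's built-in:

```python
import torch
from torch.autograd.functional import jacobian

x = torch.tensor([[2.], [3.], [5.]])
s = x.sum(0)  # tensor([10.])

# Closed form from above: J[a, b] = ∂y[a] / ∂x[b]
J = (s * torch.eye(x.size(0)) - x) / s**2

# autograd's Jacobian of y = x / x.sum(0) has shape (3, 1, 3, 1):
# output dims first, then input dims; drop the singleton column axes
J_auto = jacobian(lambda t: t / t.sum(0), x).reshape(3, 3)

print(torch.allclose(J, J_auto))  # True
```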

With multiple columns it's a bit trickier, specifically when it comes to the shape of the diagonal matrix. It's easier to keep the column axis last so we don't have to worry about broadcasting:

>>> x = torch.tensor([[2., 1], [3., 3], [5., 5]])
>>> sum = x.sum(0)
>>> sum
tensor([10.,  9.])

>>> diag = sum*torch.eye(3).unsqueeze(-1).repeat(1, 1, len(sum))
>>> diag
tensor([[[10.,  9.],
         [ 0.,  0.],
         [ 0.,  0.]],

        [[ 0.,  0.],
         [10.,  9.],
         [ 0.,  0.]],

        [[ 0.,  0.],
         [ 0.,  0.],
         [10.,  9.]]])

The diag above has shape (3, 3, 2), with the two columns on the last axis. Notice that we didn't need to reshape sum: it broadcasts along that last axis as is.

What I would avoid is torch.eye(3).unsqueeze(0).repeat(len(sum), 1, 1): with that shape, (2, 3, 3), you'd have to use sum[:, None, None] and would need further reshaping down the road.
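As an aside, since torch.eye(3).unsqueeze(-1) already has shape (3, 3, 1), broadcasting it against sum (shape (2,)) produces the (3, 3, 2) tensor on its own, so the .repeat call is optional. A quick check on the same toy input (s stands for sum):

```python
import torch

x = torch.tensor([[2., 1], [3., 3], [5., 5]])
s = x.sum(0)  # shape (2,)

# Version from the answer: explicitly repeat the identity along the column axis
diag_repeat = s * torch.eye(3).unsqueeze(-1).repeat(1, 1, len(s))

# Broadcasting alone gives the same tensor: (3, 3, 1) * (2,) -> (3, 3, 2)
diag_bcast = s * torch.eye(3).unsqueeze(-1)

print(diag_bcast.shape)                      # torch.Size([3, 3, 2])
print(torch.equal(diag_repeat, diag_bcast))  # True
```

Skipping repeat also avoids materializing a copy, although for a (3, 3) identity the difference is negligible.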

The Jacobian is simply:

>>> J = (diag - x)/sum**2
>>> J
tensor([[[ 0.0800,  0.0988],
         [-0.0300, -0.0370],
         [-0.0500, -0.0617]],

        [[-0.0200, -0.0123],
         [ 0.0700,  0.0741],
         [-0.0500, -0.0617]],

        [[-0.0200, -0.0123],
         [-0.0300, -0.0370],
         [ 0.0500,  0.0494]]])
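This three-axis J is the same-column slice of the full Jacobian, with J[a, b, c] = ∂y[b, c]/∂x[a, c]. A sketch that checks this against torch.autograd.functional.jacobian, which returns the full (3, 2, 3, 2) tensor (s again stands for sum):

```python
import torch
from torch.autograd.functional import jacobian

x = torch.tensor([[2., 1], [3., 3], [5., 5]])
s = x.sum(0)
diag = s * torch.eye(3).unsqueeze(-1)
# Layout used in the answer: J[a, b, c] = ∂y[b, c] / ∂x[a, c]
J = (diag - x) / s**2

# Full Jacobian: full[b, c, a, d] = ∂y[b, c] / ∂x[a, d], shape (3, 2, 3, 2)
full = jacobian(lambda t: t / t.sum(0), x)

# Keep only same-column entries (c == d). torch.diagonal removes dims 1 and 3
# and appends the diagonal as the last axis, giving [b, a, c]; then swap to [a, b, c]
same_col = torch.diagonal(full, dim1=1, dim2=3).permute(1, 0, 2)

print(torch.allclose(J, same_col))  # True
```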

You can check the results by backpropagating through the operation with an arbitrary dy tensor (not torch.ones, though: because each column of y always sums to 1, contracting J with ones gives all zeros). After backpropagating, x.grad should equal torch.einsum('abc,bc->ac', J, dy).
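That check can be written out as follows. This is a sketch: the fixed seed and the use of double precision are choices made here so that the comparison is not disturbed by float32 rounding.

```python
import torch

torch.manual_seed(0)
x = torch.tensor([[2., 1], [3., 3], [5., 5]], dtype=torch.double, requires_grad=True)
dy = torch.randn(3, 2, dtype=torch.double)  # arbitrary upstream gradient

# Let autograd backpropagate through the normalization
(x / x.sum(0)).backward(dy)

with torch.no_grad():
    s = x.sum(0)
    # J[a, b, c] = ∂y[b, c] / ∂x[a, c], as in the answer
    J = (s * torch.eye(3, dtype=torch.double).unsqueeze(-1) - x) / s**2
    expected = torch.einsum('abc,bc->ac', J, dy)
    # With an all-ones dy the terms cancel, since each column of y sums to 1
    from_ones = torch.einsum('abc,bc->ac', J, torch.ones(3, 2, dtype=torch.double))

print(torch.allclose(x.grad, expected))                                  # True
print(torch.allclose(from_ones, torch.zeros_like(from_ones), atol=1e-12))  # True
```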

Ivan
    I see you did the exercise ;) +1 – Shai Feb 08 '21 at 05:47
  • Thank you for your answer! I was able to implement backpropagation for the (H,W) tensor case – Jin Feb 08 '21 at 15:12
  • However, I ran into a problem with the (N,H,W) case. Even though I loop over each tensor in the batch (N), an error occurs. I'll upload my code and an image. Could you help me get the correct answer for the (N,H,W) case? Thanks! – Jin Feb 08 '21 at 15:17

Your Jacobian is incomplete: it is a 4D tensor, and you only computed a 2D slice of it.

You neglected the second row of the Jacobian:

[image of the Jacobian]
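To see the 4D structure concretely, the sketch below builds the full Jacobian of y = x / x.sum(0) for a random 3×3 input (an arbitrary choice) using torch.autograd.functional.jacobian. Entries that couple different columns are zero, but within a column every output depends on every input, and those off-diagonal entries are what an elementwise backward leaves out:

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
H, W = 3, 3
x = torch.rand(H, W, dtype=torch.double) + 0.5  # positive entries

# full[i, c, k, d] = ∂y[i, c] / ∂x[k, d], shape (H, W, H, W)
full = jacobian(lambda t: t / t.sum(0), x)
print(full.shape)  # torch.Size([3, 3, 3, 3])

# Reorder to [c, d, i, k]; entries with c != d couple different columns
by_cols = full.permute(1, 3, 0, 2)
mask = ~torch.eye(W, dtype=torch.bool)  # True where c != d
print(torch.all(by_cols[mask] == 0))    # tensor(True)

# Within one column, off-diagonal entries (i != k) are nonzero:
# ∂y[0, 0] / ∂x[1, 0] = -x[0, 0] / s[0]**2
print(full[0, 0, 1, 0], -x[0, 0] / x[:, 0].sum() ** 2)
```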

Shai

Answer for Q2.

I implemented backpropagation for the batched case myself. Using unsqueeze to line up the dimensions for broadcasting made it work.

Size of input: (N, H, W), where N is the batch size.

forward:
  out = input / torch.sum(input, dim=1).unsqueeze(1)   # (N, H, W) / (N, 1, W)

backward:
  diag = torch.eye(input.size(1), dtype=input.dtype).unsqueeze(-1)   # (H, H, 1), no requires_grad needed
  sum = input.sum(1)                                                 # (N, W)
  # (N, H, H, W) Jacobian times grad_out, summed over the output row axis -> (N, H, W)
  grad_input = ((sum.unsqueeze(1).unsqueeze(1) * diag - input.unsqueeze(1)) * grad_out.unsqueeze(1) / (sum**2).unsqueeze(1).unsqueeze(1)).sum(2)
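For completeness, here is a self-contained sketch that wraps the forward and backward above into an autograd.Function and runs gradcheck on a batch. The class name ColumnNormalize and the test shape (2, 3, 4) are illustrative choices. Unlike the loop in the question, the forward divides out of place, so the tensor passed to save_for_backward still holds the original values when backward runs.

```python
import torch
from torch.autograd import gradcheck

class ColumnNormalize(torch.autograd.Function):
    """y = x / x.sum(dim=1) for a batch of (H, W) matrices; input shape (N, H, W)."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # Out-of-place division: the saved input is left untouched
        return input / torch.sum(input, dim=1).unsqueeze(1)

    @staticmethod
    def backward(ctx, grad_out):
        (input,) = ctx.saved_tensors
        H = input.size(1)
        # (H, H, 1) identity; no requires_grad needed inside backward
        diag = torch.eye(H, dtype=input.dtype, device=input.device).unsqueeze(-1)
        s = input.sum(1)  # (N, W)
        s4 = s.unsqueeze(1).unsqueeze(1)  # (N, 1, 1, W)
        # J[n, i, j, w] = ∂y[n, j, w] / ∂x[n, i, w] = (s * delta_ij - x[n, j, w]) / s**2
        J = (s4 * diag - input.unsqueeze(1)) / s4**2
        # Contract over the output row axis j -> (N, H, W)
        return (J * grad_out.unsqueeze(1)).sum(2)

torch.manual_seed(0)
# Positive inputs keep the column sums away from zero
x = (torch.rand(2, 3, 4, dtype=torch.double) + 0.5).requires_grad_(True)
print(gradcheck(ColumnNormalize.apply, (x,)))  # True
```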
Jin