
I have a model in PyTorch. The model can have any architecture, but let's assume this is the model:

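from torch.nn import BatchNorm1d, Dropout, Flatten, Linear, ReLU, Sequential, Softmax

torch_model = Sequential(
    Flatten(),
    Linear(28 * 28, 256),
    Dropout(.4),
    ReLU(),
    BatchNorm1d(256),
    ReLU(),
    Linear(256, 128),
    Dropout(.4),
    ReLU(),
    BatchNorm1d(128),
    ReLU(),
    Linear(128, 10),
    Softmax(dim=1)  # normalize over the 10 class scores
)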

I am using the SGD optimizer, and I want to set the gradient of each layer so that SGD moves the parameters in the direction I want.

Let's say I want the gradients of all the layers to be ones (i.e. torch.ones_like(gradient)). How can I do this? Thanks!

Yedidya kfir

1 Answer


In PyTorch, with a model defined as yours above, you can iterate over the layers like this:

for layer in list(torch_model.modules())[1:]:
  print(layer)

You have to add the [1:] because the first module returned by modules() is the Sequential container itself. For any layer, you can access the weights with layer.weight. However, some layers, such as Flatten, Dropout, ReLU, and Softmax, have no weights. One way to check for this, and then add 1 to each weight, is:

import torch

with torch.no_grad():  # keep autograd from recording these in-place edits
  for layer in list(torch_model.modules())[1:]:
    if hasattr(layer, 'weight'):  # skips Flatten, Dropout, ReLU, Softmax
      layer.weight.add_(1)        # add 1 to every element of the weight

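I tested the above on your model, and it does add 1 to every weight. Note that it won't work without torch.no_grad(): PyTorch raises an error when a leaf tensor that requires grad is modified in place while autograd is tracking it, and you don't want autograd recording these changes anyway.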
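The code above changes the weights themselves rather than their gradients. For what the question actually asks (give every parameter a gradient of ones and let SGD apply it), a minimal sketch is to write each parameter's `.grad` directly and then call `optimizer.step()`. It assumes a small stand-in model and a learning rate of 0.1; the same loop should work unchanged on `torch_model`. With plain SGD (no momentum, no weight decay), every element should move by exactly `-lr`.

```python
import torch
from torch.nn import BatchNorm1d, Flatten, Linear, ReLU, Sequential

torch.manual_seed(0)

# Small stand-in for the question's torch_model (any nn.Module works the same way).
model = Sequential(Flatten(), Linear(28 * 28, 16), ReLU(), BatchNorm1d(16), Linear(16, 10))

lr = 0.1  # assumed learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# Copy every parameter so the update can be checked afterwards.
before = [p.detach().clone() for p in model.parameters()]

optimizer.zero_grad()
# Instead of calling loss.backward(), write the desired gradient into .grad directly.
for p in model.parameters():
    p.grad = torch.ones_like(p)

optimizer.step()  # plain SGD: p <- p - lr * p.grad

# Every element of every parameter should have decreased by lr.
for old, new in zip(before, model.parameters()):
    assert torch.allclose(new.detach(), old - lr)
print("all parameters decreased by", lr)
```

Any tensor with the parameter's shape, dtype, and device can be assigned to `.grad` this way, so `torch.ones_like(p)` can be swapped for whatever direction you want SGD to follow.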

StBlaize
  • Is there a way to iterate over the grad of each layer? It seems that each layer puts its parameters under a different name – Yedidya kfir May 08 '22 at 19:45
  • I'm sure there is. Are you using TensorFlow or PyTorch? – StBlaize May 09 '22 at 00:17
  • I am using PyTorch – Yedidya kfir May 09 '22 at 04:07
  • If you are using `nn.Sequential`, you can get a list of all the modules in the sequence with `list(model.sequence_name.modules())`. If not, you can make your class iterable yourself: https://stackoverflow.com/questions/19151/how-to-build-a-basic-iterator – StBlaize May 09 '22 at 10:14
  • Can you add a full code example to your answer? I don't see how you reach and change the gradient (for example, I can't find the parameter sequence_name, or where the gradient is supposed to be) – Yedidya kfir May 10 '22 at 08:38
  • I changed the answer. Should answer in full for you now. – StBlaize May 10 '22 at 22:50
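For the follow-up in the comments (how to reach the gradient of each layer), a minimal sketch: after an ordinary `backward()`, walk the layers with `named_children()` and each layer's `named_parameters()`; the gradient lives in each parameter's `.grad`. In an `nn.Sequential` the layer names are just their indices (`"1"`, `"4"`, ...), which is why parameters show up as `1.weight`, `4.bias`, and so on. The stand-in model, the random batch, and the learning rate are assumptions; the stand-in ends in raw logits because `F.cross_entropy` applies log-softmax itself.

```python
import torch
import torch.nn.functional as F
from torch.nn import BatchNorm1d, Dropout, Flatten, Linear, ReLU, Sequential

torch.manual_seed(0)

# Small stand-in for torch_model; layer indices become parameter-name prefixes.
model = Sequential(Flatten(), Linear(28 * 28, 16), Dropout(.4), ReLU(), BatchNorm1d(16), Linear(16, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # assumed learning rate

# One ordinary backward pass on a random batch (assumed data) so .grad gets populated.
x = torch.randn(8, 1, 28, 28)
y = torch.randint(0, 10, (8,))
optimizer.zero_grad()
loss = F.cross_entropy(model(x), y)
loss.backward()

# Walk the layers; parameter-free layers (Flatten, Dropout, ReLU) yield nothing here.
seen = []
for layer_name, layer in model.named_children():
    for param_name, p in layer.named_parameters():
        print(f"{layer_name}.{param_name}", type(layer).__name__, tuple(p.grad.shape))
        p.grad.fill_(1.0)  # overwrite the computed gradient in place with ones
        seen.append(f"{layer_name}.{param_name}")

optimizer.step()  # SGD now applies the all-ones gradients
```

Overwriting `.grad` between `backward()` and `step()` is what lets you steer the update; the optimizer only reads whatever is stored there.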