
I'm using PyTorch and trying to register hooks on model parameters. The following code creates a lambda for each model parameter and registers it as a hook, so that inside the hook I can see which tensor the gradient belongs to:

import torch
import torchvision

# define model and random train batch
model = torchvision.models.alexnet()
input = torch.rand(10, 3, 224, 224)   # batch of 10 images
targets = torch.zeros(10).long()

def grad_hook_template(param, name, grad):
    print(f'Received grad for {name} with shape {grad.shape}')

# add one lambda hook to each parameter
for name, param in model.named_parameters():
    print(f'Register hook for {name}')

    # use a lambda so we can pass additional information to the hook, which should only take one parameter
    param.register_hook(lambda grad: grad_hook_template(param, name, grad))

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.zero_grad()

prediction = model(input)
loss = loss_fn(prediction, targets)
loss.backward()
optimizer.step()

The result is that the name and param arguments to grad_hook_template always have the same value (and id), while the grad argument is different each time (as expected). Why do the lambdas I register all seem to refer to the same local variables?

I read e.g. here that loops do not create new scopes and closures are lexical in Python, i.e. the name and param which I'm passing to the lambda are just pointers and whatever value they have at the end of the loop is seen by everyone with this pointer. But what can I do about it? copy.copy() the variables?
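A quick pure-Python check, using hypothetical names w1 to w3 in place of the model's parameters, suggests copy.copy() on its own would not help: whatever the copy returns is still assigned to one loop-level variable, and the lambda looks that variable up only when it is called.

```python
import copy

hooks = []
for name in ["w1", "w2", "w3"]:  # hypothetical parameter names
    # Whatever copy.copy returns, it is bound to the same variable every iteration.
    name_copy = copy.copy(name)
    hooks.append(lambda: name_copy)

print([h() for h in hooks])  # ['w3', 'w3', 'w3']
```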

oarfish

2 Answers


This is kind of answered by the FAQ.

Solutions include

  • using functools.partial instead of a lambda
  • using default parameter values in the lambdas to capture the current values of the variables
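Both solutions can be tried without PyTorch. The sketch below uses a hypothetical FakeParam class as a stand-in for a tensor; its register_hook and fire methods are assumptions that only mimic autograd calling each hook with a single positional gradient argument.

```python
from functools import partial


class FakeParam:
    """Hypothetical stand-in for a tensor: stores hooks and calls each with a grad."""

    def __init__(self, label):
        self.label = label
        self.hooks = []

    def register_hook(self, hook):
        self.hooks.append(hook)

    def fire(self, grad):
        # Mimics autograd calling hook(grad) with one positional argument.
        return [hook(grad) for hook in self.hooks]


def grad_hook_template(param, name, grad):
    return f"{name} ({param.label}) got grad {grad}"


# Hypothetical (name, param) pairs standing in for model.named_parameters().
named_parameters = [("w1", FakeParam("p1")), ("w2", FakeParam("p2"))]

for name, param in named_parameters:
    # Solution 1: functools.partial stores the current param and name immediately.
    param.register_hook(partial(grad_hook_template, param, name))
    # Solution 2: default argument values are evaluated when the lambda is defined.
    param.register_hook(lambda grad, name=name, param=param: grad_hook_template(param, name, grad))

for name, param in named_parameters:
    print(param.fire(0.5))
```

Both hooks on each FakeParam report that parameter's own name, since neither reads the loop variables after registration.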
oarfish

You've run into late binding closures. The variables param and name are looked up when the function is called, not when it is defined. By the time any of these functions are called, name and param hold their values from the last loop iteration. To get around this, you could do this:

for name, param in model.named_parameters():
    print(f'Register hook for {name}')
    param.register_hook(lambda grad, name=name, param=param: grad_hook_template(param, name, grad))
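The difference in timing can be checked without PyTorch. In this sketch (hypothetical names w1 to w3), the plain lambdas read name when they are called, while the default-argument lambdas store the value when they are defined, so even rebinding name after the loop does not affect them.

```python
late_bound = []
default_bound = []
for name in ["w1", "w2", "w3"]:  # hypothetical parameter names
    # Late binding: the body looks up `name` only when the lambda is called.
    late_bound.append(lambda: name)
    # Default value: `name` is evaluated now and stored with the function.
    default_bound.append(lambda name=name: name)

print([f() for f in late_bound])     # ['w3', 'w3', 'w3']
print([f() for f in default_bound])  # ['w1', 'w2', 'w3']

name = "reassigned"  # rebinding the variable after the loop
print([f() for f in late_bound])     # ['reassigned', 'reassigned', 'reassigned']
print([f() for f in default_bound])  # ['w1', 'w2', 'w3']
```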

However, I think using functools.partial is the right solution here:

from functools import partial

for name, param in model.named_parameters():
    print(f'Register hook for {name}')
    # bind param and name positionally; the hook is then called as hook(grad)
    param.register_hook(partial(grad_hook_template, param, name))
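One detail worth checking with partial: grad_hook_template takes param first, and a tensor hook is called with the gradient as its single positional argument. If param and name are bound as keywords, the gradient lands in the param slot and the call fails; binding them positionally, or moving grad to the front of the signature, avoids that. A pure-Python sketch, with hypothetical values "p1" and "w1" instead of real tensors:

```python
from functools import partial


def grad_hook_template(param, name, grad):
    return (param, name, grad)


# Keyword binding: calling hook(grad) passes grad positionally into `param`,
# which is also supplied by keyword, so Python raises TypeError.
keyword_hook = partial(grad_hook_template, name="w1", param="p1")
try:
    keyword_hook(0.5)
    error = None
except TypeError as exc:
    error = str(exc)
print(error)  # ...got multiple values for argument 'param'

# Positional binding leaves grad as the only remaining (third) argument.
positional_hook = partial(grad_hook_template, "p1", "w1")
print(positional_hook(0.5))  # ('p1', 'w1', 0.5)


# Alternative: put grad first in the signature so keyword binding works.
def grad_first_template(grad, param, name):
    return (param, name, grad)


keyword_ok = partial(grad_first_template, name="w1", param="p1")
print(keyword_ok(0.5))  # ('p1', 'w1', 0.5)
```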

You can find more information about late binding closures at the Common Gotchas page of the Hitchhiker's Guide to Python as well as in the Python docs.

Note that this applies equally to functions defined with the def keyword.
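A minimal sketch of that point, with hypothetical names w1 to w3: a def inside the loop is late bound in exactly the same way. A third fix, not among the solutions listed in the answers, is a factory function; each call creates a fresh scope, so the inner function captures that call's value. The helper name make_hook is hypothetical.

```python
names = ["w1", "w2", "w3"]  # hypothetical parameter names

# A def inside the loop still closes over the loop variable, so it is late bound too.
late_hooks = []
for name in names:
    def hook(grad):
        return f"{name} got {grad}"
    late_hooks.append(hook)


def make_hook(name):
    # Each call to make_hook has its own scope, so the inner function
    # captures this call's `name` rather than the loop variable.
    def hook(grad):
        return f"{name} got {grad}"
    return hook


factory_hooks = [make_hook(n) for n in names]

print([h(0.5) for h in late_hooks])     # ['w3 got 0.5', 'w3 got 0.5', 'w3 got 0.5']
print([h(0.5) for h in factory_hooks])  # ['w1 got 0.5', 'w2 got 0.5', 'w3 got 0.5']
```

In the question's loop this would become param.register_hook(make_hook(param, name)), with the factory taking both values and forwarding them to grad_hook_template.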

Cyphase