I'm using PyTorch and trying to register hooks on model parameters. The following code attaches a lambda hook to each model parameter so that, inside the hook, I can see which tensor the gradient belongs to:
import torch
import torchvision
# define model and random train batch
model = torchvision.models.alexnet()
input = torch.rand(10, 3, 224, 224) # batch of 10 images
targets = torch.zeros(10).long()
def grad_hook_template(param, name, grad):
    print(f'Received grad for {name} with shape {grad.shape}')
# add one lambda hook to each parameter
for name, param in model.named_parameters():
    print(f'Register hook for {name}')
    # use a lambda so we can pass additional information to the hook, which should only take one parameter
    param.register_hook(lambda grad: grad_hook_template(param, name, grad))
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.zero_grad()
prediction = model(input)
loss = loss_fn(prediction, targets)
loss.backward()
optimizer.step()
The result is that the name and param arguments to grad_hook_template always have the same value (and id), but the grad argument is always different (as expected). Why is it that when I register the hook, the lambdas seem to refer to the same local variables each time?
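For what it's worth, the same behaviour shows up with plain Python closures, without any PyTorch involved; here is a minimal sketch of what I mean:

# minimal reproduction without PyTorch: both lambdas see the last loop value
hooks = []
for name in ['weight', 'bias']:
    hooks.append(lambda grad: print(f'hook sees name={name}, grad={grad}'))
for hook in hooks:
    hook(0.0)  # prints name=bias both times, because name is looked up at call time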
I read, e.g. here, that loops do not create new scopes and that closures in Python are lexical, i.e. the name and param I'm passing to the lambda are just references, and whatever value they hold at the end of the loop is what every closure sees. But what can I do about it? copy.copy() the variables?
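Or would it be better to bind the current values at definition time instead of copying them? Something along these lines is what I had in mind (untested sketch, using either lambda default arguments or functools.partial):

import functools

# option 1: bind the current name/param as default arguments of the lambda
for name, param in model.named_parameters():
    param.register_hook(
        lambda grad, name=name, param=param: grad_hook_template(param, name, grad))

# option 2: pre-fill the extra arguments with functools.partial
for name, param in model.named_parameters():
    param.register_hook(functools.partial(grad_hook_template, param, name))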