I'm trying to understand why I cannot directly overwrite the weights of a torch layer. Consider the following example:
import torch
from torch import nn
net = nn.Linear(3, 1)
weights = torch.zeros(1,3)
# Rebinding the dict key does not work
net.state_dict()["weight"] = weights  # no effect on the layer
print(f"{net.state_dict()['weight']=}")
# But mutating the tensor in place does work
net.state_dict()["weight"][0] = weights  # indexing mutates the underlying tensor
print(f"{net.state_dict()['weight']=}")
#########
# output:
# net.state_dict()['weight']=tensor([[ 0.5464, -0.4110, -0.1063]])
# net.state_dict()['weight']=tensor([[0., 0., 0.]])
I'm confused because state_dict()["weight"] is just a regular torch tensor, so I feel like I'm missing something obvious here.