Why do the bytes obtained from serializing a PyTorch state_dict change after loading that state_dict into a new instance of the same model architecture?
Have a look:
import binascii
import torch.nn as nn
import pickle
lin1 = nn.Linear(1, 1, bias=False)
lin1s = pickle.dumps(lin1.state_dict())
print("--- original model ---")
print(f"hash of state dict: {hex(binascii.crc32(lin1s))}")
print(f"weight: {lin1.state_dict()['weight'].item()}")
lin2 = nn.Linear(1, 1, bias=False)
lin2.load_state_dict(pickle.loads(lin1s))
lin2s = pickle.dumps(lin2.state_dict())
print("\n--- model from deserialized state dict ---")
print(f"hash of state dict: {hex(binascii.crc32(lin2s))}")
print(f"weight: {lin2.state_dict()['weight'].item()}")
This prints:
--- original model ---
hash of state dict: 0x4806e6b6
weight: -0.30337071418762207
--- model from deserialized state dict ---
hash of state dict: 0xe2881422
weight: -0.30337071418762207
As you can see, the hashes of the (pickles of the) state_dicts are different, whereas the weight is copied over correctly. I would have assumed that a state_dict from the new model equals the old one in every aspect. Seemingly it does not, hence the different hashes.
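For what it's worth, equal values do not guarantee identical pickle bytes in general: pickle records object identity via its memo table, so two containers that compare equal can still serialize to different byte strings. A minimal pure-Python sketch of that effect (no torch involved; the dict contents are just illustrative):

```python
import pickle

# Two dicts that compare equal but are built differently:
# d1 shares a single list object between both keys, while d2
# holds two separate-but-equal lists.
a = [1.0]
d1 = {"w": a, "v": a}          # shared reference
d2 = {"w": [1.0], "v": [1.0]}  # equal values, distinct objects

assert d1 == d2                              # values compare equal...
assert pickle.dumps(d1) != pickle.dumps(d2)  # ...but the bytes differ
```

I suspect something analogous (tensor metadata, storage identity, or similar bookkeeping inside the state_dict) explains the differing hashes above, but I don't know which detail is responsible.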