
Why do the bytes obtained from serializing PyTorch state_dicts change after loading a state_dict into a new instance of the same model architecture?

Have a look:

import binascii
import torch.nn as nn
import pickle

lin1 = nn.Linear(1, 1, bias=False)
lin1s = pickle.dumps(lin1.state_dict())
print("--- original model ---")
print(f"hash of state dict: {hex(binascii.crc32(lin1s))}")
print(f"weight: {lin1.state_dict()['weight'].item()}")

lin2 = nn.Linear(1, 1, bias=False)
lin2.load_state_dict(pickle.loads(lin1s))
lin2s = pickle.dumps(lin2.state_dict())
print("\n--- model from deserialized state dict ---")
print(f"hash of state dict: {hex(binascii.crc32(lin2s))}")
print(f"weight: {lin2.state_dict()['weight'].item()}")

This prints:

--- original model ---
hash of state dict: 0x4806e6b6
weight: -0.30337071418762207

--- model from deserialized state dict ---
hash of state dict: 0xe2881422
weight: -0.30337071418762207

As you can see, the hashes of the (pickles of the) state_dicts are different whereas the weight is copied over correctly. I would assume that a state_dict from the new model equals the old one in every aspect. Seemingly, it does not, hence the different hashes.
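A way to see that equal values do not guarantee equal pickle bytes, independent of PyTorch entirely (a minimal illustration with plain dicts; dicts preserve insertion order, and pickle serializes the items in that order):

```python
import pickle

# two dicts that compare equal but were built in different insertion order
a = {"x": 1, "y": 2}
b = {"y": 2, "x": 1}

print(a == b)                              # True
print(pickle.dumps(a) == pickle.dumps(b))  # False
```

So byte-for-byte identity of the pickle is a stricter condition than equality of the objects.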

jonas
    This might be because pickle is not expected to produce a repr suitable for hashing. https://stackoverflow.com/questions/12727571/using-pickle-dumps-to-hash-mutable-objects/12739361#12739361 It might be a better idea to compare keys, and then compare tensors stored in the dict-keys for equality/closeness. – Umang Gupta Mar 29 '21 at 21:53
  • @UmangGupta that sounds very plausible, I did not know that. Thanks! I don't know if your answer (that's still kind of a hypothesis) counts as "real" answer, but I'll mark it as solution if you post it as such. – jonas Mar 30 '21 at 18:50
  • I elaborated my response further in the answer. I hope you find it helpful. – Umang Gupta Mar 30 '21 at 19:21

1 Answer


This might be because pickle is not expected to produce a representation suitable for hashing (see Using pickle.dumps to hash mutable objects). It might be a better idea to compare the keys, and then compare the tensors stored under those keys for equality/closeness.

Below is a rough implementation of that idea.

import torch

def compare_state_dict(dict1, dict2):
    # compare keys in both directions
    for key in dict1:
        if key not in dict2:
            return False

    for key in dict2:
        if key not in dict1:
            return False

    # compare the tensors stored under each key
    for (k, v) in dict1.items():
        if not torch.all(torch.isclose(v, dict2[k])):
            return False

    return True
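A quick usage sketch of the idea above, applied to the question's two-layer setup (self-contained, with a condensed version of the comparison so it runs on its own):

```python
import torch
import torch.nn as nn

def compare_state_dict(dict1, dict2):
    # same keys in both dicts, and all tensors close in value
    if dict1.keys() != dict2.keys():
        return False
    return all(torch.all(torch.isclose(v, dict2[k])) for k, v in dict1.items())

lin1 = nn.Linear(1, 1, bias=False)
lin2 = nn.Linear(1, 1, bias=False)
lin2.load_state_dict(lin1.state_dict())

print(compare_state_dict(lin1.state_dict(), lin2.state_dict()))  # True
lin2.weight.data += 1.0
print(compare_state_dict(lin1.state_dict(), lin2.state_dict()))  # False
```

This reports equality where the pickle-bytes comparison in the question reports a mismatch.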

However, if you would still like to hash a state dict and avoid comparisons like isclose above, you can use a function like the one below.

def dict_hash(dictionary):
    # hash on a copy so the caller's state dict is not modified;
    # the tensors themselves need to be hashed first
    hashed = {k: hash(v) for (k, v) in dictionary.items()}

    # dictionaries are not hashable and need to be converted to a frozenset
    return hash(frozenset(sorted(hashed.items(), key=lambda x: x[0])))
Umang Gupta
    thanks but I stumbled upon one problem: hashing equal tensors yields different hashes (I guess it doesn't hash the content but the memory address (?)). This can be solved by hashing as follows instead: hash(v.numpy().tobytes()) – jonas Apr 03 '21 at 11:32
  • I am not 100% sure, but you might be correct. Also, `isclose` might be a better choice, since it is going to take almost the same amount of work to hash as to compare... one or two passes through the data. – Umang Gupta Apr 03 '21 at 16:22
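Following the `.numpy().tobytes()` suggestion in the comments, a content-based variant could look like this (a sketch; `dict_hash_by_value` is an illustrative name, and the `detach().cpu()` calls are there so it also works for tensors that require grad or live on a GPU):

```python
import torch
import torch.nn as nn

def dict_hash_by_value(state_dict):
    # hash tensor *contents* rather than tensor objects, so that
    # equal-valued tensors produce equal hashes within one process
    items = sorted((k, hash(v.detach().cpu().numpy().tobytes()))
                   for k, v in state_dict.items())
    return hash(tuple(items))

lin1 = nn.Linear(1, 1, bias=False)
lin2 = nn.Linear(1, 1, bias=False)
lin2.load_state_dict(lin1.state_dict())

print(dict_hash_by_value(lin1.state_dict())
      == dict_hash_by_value(lin2.state_dict()))  # True
```

Note that Python's built-in `hash` on bytes is salted per process, so these hashes are only stable within a single run; for a persistent fingerprint, a digest such as `hashlib.sha256` over the same bytes would be needed.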