I am training a model in PyTorch, for which I have defined a class like so:
import torch
from torch import nn

class myNN(nn.Module):
    def __init__(self, dense1=128, dense2=64, dense3=32, ...):
        super().__init__()  # must run before submodules are assigned
        self.MLP = nn.Sequential(
            nn.Linear(dense1, dense2),
            nn.ReLU(),
            nn.Linear(dense2, dense3),
            nn.ReLU(),
            nn.Linear(dense3, 1)
        )
        ...
In order to save it I am using:
torch.save(model.state_dict(), checkpoint_model_path)
and to load it I am using:
model = myNN() # or with specified parameters
model.load_state_dict(torch.load(model_file))
However, for this to work I have to pass the right values to the myNN() constructor. That means I would need to somehow remember or store which parameters (layer sizes) I used for each saved model in order to load it properly.
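To make the problem concrete, here is a minimal sketch of the failure. `TinyMLP` is a hypothetical stand-in for my class that keeps only the three size parameters shown above, and the non-default sizes are made up for illustration:

```python
import torch
from torch import nn


class TinyMLP(nn.Module):
    """Hypothetical stand-in for myNN with only the three size parameters."""

    def __init__(self, dense1=128, dense2=64, dense3=32):
        super().__init__()
        self.MLP = nn.Sequential(
            nn.Linear(dense1, dense2),
            nn.ReLU(),
            nn.Linear(dense2, dense3),
            nn.ReLU(),
            nn.Linear(dense3, 1),
        )


# Weights from a model that was built with non-default sizes.
saved_state = TinyMLP(dense1=256, dense2=32, dense3=16).state_dict()

# Loading them into a model built with the default sizes fails,
# because the parameter shapes no longer match.
try:
    TinyMLP().load_state_dict(saved_state)
    load_failed = False
except RuntimeError as err:
    load_failed = True
    print(f"load_state_dict failed with {type(err).__name__}")
```

So unless I know that this particular file was trained with (256, 32, 16), I cannot construct a model that accepts its weights.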
Is there a flexible way to save/load models in PyTorch that would also read the layer sizes?
E.g. by loading a myNN() object directly or somehow reading the layer sizes from the saved pickle file?
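By "reading the layer sizes" I mean something along these lines. This is only a rough sketch: the `nn.Sequential` below is a made-up stand-in for the MLP part of a saved myNN, and it relies on `nn.Linear` storing its weight with shape `(out_features, in_features)`. I am not sure how well this would hold up for the rest of the model:

```python
import torch
from torch import nn

# Stand-in for the state_dict of a saved model; sizes are made up.
mlp = nn.Sequential(
    nn.Linear(256, 32), nn.ReLU(), nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 1)
)
# Mimic the key names myNN would produce ("MLP.0.weight", "MLP.2.weight", ...).
state = {f"MLP.{key}": value for key, value in mlp.state_dict().items()}

# nn.Linear weights have shape (out_features, in_features),
# so the layer sizes can be recovered from the tensor shapes.
dense2, dense1 = state["MLP.0.weight"].shape
dense3 = state["MLP.2.weight"].shape[0]
sizes = (dense1, dense2, dense3)
print(sizes)
```

This depends on knowing the key names and the architecture in advance, which is why I am asking whether there is a cleaner way.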
I am hesitant to try the second method in "Best way to save a trained model in PyTorch?" due to the warnings mentioned there. Is there a better way to achieve what I want?