During prototyping, I often make numerous changes to a PyTorch model. For instance, suppose the first model I am experimenting with is:
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, **args):
        super().__init__()
        self.l1 = nn.Linear(128, 1)
then I will add another layer:
class Model(nn.Module):
    def __init__(self, **args):
        super().__init__()
        self.l1 = nn.Linear(128, 32)
        self.l2 = nn.Linear(32, 1)
or maybe add some convolutions, and so on.
The problem is that the more experiments I run, the more disorganized I get, because I haven't found a straightforward way of saving both the model definition and its weights so that I can load a previous state.
I know I can do:
torch.save({'model': model, 'state': model.state_dict()}, path)
# or directly
torch.save(model, path)
but then loading the model also requires that the model class (here, Model) exists in the file that does the loading.
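For example, here is a minimal sketch of what happens in a separate script that does not define the class ('model.pt' is a hypothetical path written by torch.save(model, ...) above):

import torch

# This script neither defines nor imports the Model class.
model = torch.load('model.pt')
# Unpickling fails with something like:
# AttributeError: Can't get attribute 'Model' on <module '__main__'>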
In Keras you can simply do:
model = ... # Get model (Sequential, Functional Model, or Model subclass)
model.save('path/to/location')
which saves the model's architecture/config and weights, among other things. This means that you can load the model without having defined the architecture:
model = keras.models.load_model('path/to/location')
Referring to the Keras documentation on model saving:
The SavedModel and HDF5 file contains:
- the model's configuration (topology)
- the model's weights
- the model's optimizer's state (if any)
Thus models can be reinstantiated in the exact same state, without any of the code used for model definition or training.
This is what I want to achieve in PyTorch. Is there a similar approach? What is the best practice for these situations?
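(For completeness: I am aware that TorchScript can bundle code and weights into a single file, as in the minimal sketch below, where 'model.pt' is a hypothetical filename. But I am not sure whether this is the recommended practice for a prototyping workflow like mine.)

import torch

scripted = torch.jit.script(model)  # compiles the module, capturing code and weights
scripted.save('model.pt')

# Later, in a file with no access to the Model class:
restored = torch.jit.load('model.pt')
restored.eval()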