
I have a simple neural network:


import torch
from torch import nn


class NeuralNetwork(nn.Module):
    def __init__(self, n_inputs):
        super().__init__()
        self.l1 = nn.Linear(n_inputs, 10)

    def forward(self, x):
        return self.l1(x)


net = NeuralNetwork(n_inputs=10)
net = train_network(net)  # train_network is not shown here

# Pickle the entire model object, not just its state_dict
torch.save(net, "./pretrained_models/model_cnn.pt")

Then, when I want to load it in another script, I understand that a class with the same name has to be defined before loading works, e.g.:

import torch
from torch import nn


# Stub with the same name as the saved class, so that unpickling can find it
class NeuralNetwork(nn.Module):
    pass


net = torch.load("./pretrained_models/model_cnn.pt")
net.eval()
X = load_data(data_path)
pred = net(X)  # <--- NotImplementedError

Do I then need to rewrite `__init__` and `forward` when loading? I'm trying to load the entire model instead of the `state_dict`, since I want to avoid having to rewrite the network in production whenever we change it in development.
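
The NotImplementedError follows from how pickling works: `torch.save(net, ...)` pickles the object, and pickle records a class by its module and name rather than storing the class's code. The sketch below reproduces the effect with the standard library only. `Base` is a hypothetical stand-in for `nn.Module` and `scale` is an illustrative attribute; neither comes from the code above.

```python
import pickle


class Base:
    """Hypothetical stand-in for nn.Module: forward must be overridden."""

    def forward(self, x):
        raise NotImplementedError


class NeuralNetwork(Base):
    def __init__(self, scale):
        self.scale = scale  # illustrative attribute, plays the role of the weights

    def forward(self, x):
        return self.scale * x


# "Training" script: pickle stores the instance attributes plus a
# reference to the class (module and name), not the class's code.
blob = pickle.dumps(NeuralNetwork(scale=3))


# "Loading" script: a stub with the same name replaces the real class.
class NeuralNetwork(Base):
    pass


# Unpickling looks the class up by name and finds the stub.
restored = pickle.loads(blob)
print(restored.scale)  # the saved attributes are restored

try:
    restored.forward(2)
    outcome = "forward worked"
except NotImplementedError:
    outcome = "NotImplementedError"
print(outcome)
```

The attributes come back, but the methods are whatever the class found at load time defines, so the stub's missing `forward` falls through to the base class.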

CutePoison
  • You can (and should) put `NeuralNetwork` in its own `net.py` file which you can import. – Mateen Ulhaq May 14 '21 at 12:06
  • So you would have a folder containing only the structure of the network, and then import that in the main file? Why do we have the `torch.load` method for loading entire networks when we are forced to keep the structure saved (just as when we use `load_state_dict`)? – CutePoison May 14 '21 at 12:09
  • Does this answer your question? [Saving PyTorch model with no access to model class code](https://stackoverflow.com/questions/59287728/saving-pytorch-model-with-no-access-to-model-class-code) – GoodDeeds May 14 '21 at 12:14
  • It depends on the model -- some models do not have a static graph. If your model does, you can do what @GoodDeeds says. But putting the model definition in its own file, committing it to a version control system, and using the commit to track your models isn't too bad an idea... this way, you have a human-readable model definition that you can refer to for any given model file. – Mateen Ulhaq May 14 '21 at 12:17
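
For a model with a static graph, the approach from the linked question can be sketched with TorchScript. This is a minimal sketch, not necessarily what the commenters had in mind: it assumes the network scripts cleanly with `torch.jit.script`, and the temporary file name `model_scripted.pt` is made up for illustration.

```python
import os
import tempfile

import torch
from torch import nn


class NeuralNetwork(nn.Module):
    def __init__(self, n_inputs):
        super().__init__()
        self.l1 = nn.Linear(n_inputs, 10)

    def forward(self, x):
        return self.l1(x)


net = NeuralNetwork(n_inputs=10).eval()

# Development side: compile the module to TorchScript and save it.
# The file holds both the parameters and the compiled forward logic.
scripted = torch.jit.script(net)
path = os.path.join(tempfile.mkdtemp(), "model_scripted.pt")  # hypothetical file name
scripted.save(path)

# Production side: no NeuralNetwork class definition is needed to load.
loaded = torch.jit.load(path)
loaded.eval()

# Check that the loaded module reproduces the original outputs.
x = torch.randn(4, 10)
with torch.no_grad():
    same = torch.allclose(net(x), loaded(x))
print(same)
```

The trade-off is the one raised in the comments: the saved file is self-contained, but it is no longer a human-readable model definition, so keeping the class in version control is still useful.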
