I have a simple neural network:
```python
import torch
import torch.nn as nn

class NeuralNetwork(nn.Module):
    def __init__(self, n_inputs):
        super(NeuralNetwork, self).__init__()
        self.l1 = nn.Linear(n_inputs, 10)

    def forward(self, x):
        return self.l1(x)

net = NeuralNetwork(n_inputs=10)
net = train_network(net)  # my training helper, defined elsewhere
torch.save(net, "./pretrained_models/model_cnn.pt")
```
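A quick way to see what `torch.save(net, ...)` writes for a whole module: by default the file is a zip archive whose `data.pkl` entry is an ordinary pickle, and a pickle records the model's class only by module and name, together with the instance's attributes (parameters, submodules). The sketch below is an illustration under assumptions: it skips training, saves to a temporary directory instead of `./pretrained_models/`, and simply searches the raw pickle bytes for the class name and for the source of `forward`.

```python
import os
import tempfile
import zipfile

import torch
import torch.nn as nn


class NeuralNetwork(nn.Module):
    def __init__(self, n_inputs):
        super().__init__()
        self.l1 = nn.Linear(n_inputs, 10)

    def forward(self, x):
        return self.l1(x)


net = NeuralNetwork(n_inputs=10)  # untrained; only the file layout matters here

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_cnn.pt")
    torch.save(net, path)  # pickles the whole module object
    with zipfile.ZipFile(path) as archive:
        # The pickle entry is named "<archive>/data.pkl"; match on the suffix.
        pkl_name = next(n for n in archive.namelist() if n.endswith("data.pkl"))
        payload = archive.read(pkl_name)

print(b"NeuralNetwork" in payload)  # the class is referenced by name
print(b"def forward" in payload)    # the method's code is not stored
```

Because only the name is stored, whatever `NeuralNetwork` the loading script provides is the class the restored object ends up with.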
Then, when I want to load it again in another script, I understand that a class with the same name has to be defined before it works, e.g.
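```python
import torch
import torch.nn as nn

class NeuralNetwork(nn.Module):
    pass

# weights_only=False is needed on PyTorch >= 2.6 to unpickle a whole module
net = torch.load("./pretrained_models/model_cnn.pt", weights_only=False)
net.eval()
X = load_data(data_path)
pred = net(X)  # <--- NotImplementedError
```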
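The error follows from pickling by reference: while unpickling, the loader looks the class up by name in the loading script, finds the stub, and restores the saved attributes onto an instance of it. The stub has no `forward`, so the call falls through to `nn.Module`'s default `forward`, which raises `NotImplementedError`. The pure-Python sketch below reproduces that behaviour with the standard `pickle` module; `ToyModule`, the module name `toy_models`, and the in-place swap of the class are hypothetical stand-ins for `nn.Module` and for the two separate scripts.

```python
import pickle
import sys
import types


class ToyModule:
    """Hypothetical stand-in for nn.Module: the default forward is unimplemented."""

    def forward(self, x):
        raise NotImplementedError


# Hypothetical module that owns the class, registered so pickle can find it by name.
toy_models = types.ModuleType("toy_models")
sys.modules["toy_models"] = toy_models


class NeuralNetwork(ToyModule):
    """Version of the model seen by the training script."""

    def __init__(self):
        self.weight = 2.0

    def forward(self, x):
        return self.weight * x


NeuralNetwork.__module__ = "toy_models"
toy_models.NeuralNetwork = NeuralNetwork

# "Training script": pickle stores toy_models.NeuralNetwork by name plus __dict__.
payload = pickle.dumps(NeuralNetwork())


# "Loading script": a stub without forward, like the `pass` class in the question,
# bound under the name that pickle will look up.
class StubNetwork(ToyModule):
    pass


toy_models.NeuralNetwork = StubNetwork

loaded = pickle.loads(payload)
print(type(loaded) is StubNetwork, loaded.weight)  # saved attributes survive

try:
    loaded.forward(3.0)
    error_name = None
except NotImplementedError:
    error_name = "NotImplementedError"
print(error_name)
```

So with a full-model save, `__init__` is not rerun on load (the attributes come from the file), but `forward` and any other methods must come from a class definition available under the same module and name.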
Do I then need to rewrite `__init__` and `forward` when loading? I'm trying to load the entire model instead of the `state_dict` because I want to avoid rewriting the network in production whenever we change it during development.
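If the aim is a file that production can load without any Python class definition, one option (a different serialization technique from pickling the module, not a fix for `torch.load` itself) is TorchScript: `torch.jit.script` compiles `forward` into the saved file, and `torch.jit.load` rebuilds a callable module from it. A minimal sketch, assuming the model is scriptable (this toy `NeuralNetwork` is) and using a temporary directory with the hypothetical file name `model_scripted.pt` in place of the paths above:

```python
import os
import tempfile

import torch
import torch.nn as nn


class NeuralNetwork(nn.Module):
    def __init__(self, n_inputs):
        super().__init__()
        self.l1 = nn.Linear(n_inputs, 10)

    def forward(self, x):
        return self.l1(x)


net = NeuralNetwork(n_inputs=10).eval()
x = torch.randn(4, 10)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_scripted.pt")
    # Development side: compile forward() to TorchScript and save code + weights.
    torch.jit.script(net).save(path)

    # Production side: no NeuralNetwork class is needed to load this file.
    loaded = torch.jit.load(path)
    loaded.eval()

with torch.no_grad():
    expected = net(x)
    out = loaded(x)

print(out.shape)
```

The other common route keeps the class in a package that both the training code and the production code import, so a pickled full model (or a `state_dict`) can always find the current definition by name.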