0
import torch
import torchvision

# Training hyperparameters.
n_epochs = 3
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10  # print/record the training loss every this many batches

# Seed the global torch RNG before anything that draws from it is created.
# NOTE(review): cuDNN is switched off here, presumably for reproducibility.
random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

# Training split of MNIST, downloaded into ./files on first use.
train_loader = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.MNIST(
        root='./files',
        train=True,
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
        ]),
    ),
    batch_size=batch_size_train,
    shuffle=True,
)

# Test split, built with the same transform pipeline as the training split.
test_loader = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.MNIST(
        root='./files',
        train=False,
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
        ]),
    ),
    batch_size=batch_size_test,
    shuffle=True,
)

# Pull a single batch from the test loader to keep around as example data.
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)


import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
    """Small CNN: two conv/max-pool stages followed by two linear layers.

    ``forward`` returns log-probabilities over 10 classes, which is the
    input format expected by ``F.nll_loss``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return per-class log-probabilities, one row of 10 per sample.

        The flatten step hard-codes 320 features per sample
        (20 channels * 4 * 4), which is what a 1x28x28 input produces.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        # Functional dropout has to be told the mode explicitly; the
        # Dropout2d module above follows train()/eval() on its own.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Normalize over the class dimension explicitly: omitting `dim`
        # relies on a deprecated implicit choice and emits a warning.
        return F.log_softmax(x, dim=1)

# The model to train and the SGD optimizer that updates its parameters.
network = Net()
optimizer = optim.SGD(
    params=network.parameters(),
    lr=learning_rate,
    momentum=momentum,
)

# Histories filled in by train() and test().
train_losses, train_counter, test_losses = [], [], []
# Number of training samples seen after 0, 1, ..., n_epochs full epochs.
test_counter = [
    epoch_idx * len(train_loader.dataset) for epoch_idx in range(n_epochs + 1)
]

def train(epoch):
  """Run one training epoch over ``train_loader``.

  Uses the module-level ``network``, ``optimizer``, ``train_loader``,
  ``log_interval`` and ``batch_size_train``. Every ``log_interval`` batches
  the current loss is printed and appended to ``train_losses``, and the
  matching number of samples seen so far is appended to ``train_counter``.

  Args:
    epoch: 1-based index of the epoch being run; used for logging and for
      offsetting the sample counter by the epochs already completed.
  """
  network.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = network(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))
      train_losses.append(loss.item())
      # Use the configured batch size rather than a hard-coded 64 so the
      # counter stays correct if batch_size_train is changed.
      train_counter.append(
        (batch_idx * batch_size_train)
        + ((epoch - 1) * len(train_loader.dataset)))

def test():
  """Evaluate ``network`` on ``test_loader`` and report the results.

  Runs in eval mode with gradients disabled, appends the average per-sample
  loss to ``test_losses`` and prints the loss and accuracy. Relies on the
  module-level ``network``, ``test_loader`` and ``test_losses``.
  """
  network.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      output = network(data)
      # Sum per batch (not mean) so that dividing by the dataset size below
      # yields the true per-sample average. `reduction='sum'` replaces the
      # deprecated `size_average=False`.
      test_loss += F.nll_loss(output, target, reduction='sum').item()
      pred = output.max(1, keepdim=True)[1]
      # .item() keeps `correct` a plain Python int instead of a tensor.
      correct += pred.eq(target.view_as(pred)).sum().item()
  test_loss /= len(test_loader.dataset)
  test_losses.append(test_loss)
  print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

# Driver: evaluate the untrained network once, then alternate one training
# epoch with one evaluation pass.
# NOTE(review): these statements sit at module top level, so they execute
# whenever this file is run or imported.
test()
for epoch in range(1, n_epochs + 1):
  train(epoch)
  test()

# Persist only the parameters (the state dict), not the module object itself.
# NOTE(review): presumably './results' must already exist - confirm.
torch.save(network.state_dict(), './results/model.pth')

Other file:

PATH = "results/model.pth"
# NOTE(review): the training script saved `network.state_dict()`, so this
# returns that saved state dict rather than a `Net` instance.
model = torch.load(PATH)

When this is called, instead of just loading the model parameters, PyTorch retrains the entire model. The model is retrained in exactly the same way (i.e., it takes the exact same steps to reach the same local minimum).

PATH = "results/model.pth"
# `Net` has to be in scope here.
# NOTE(review): presumably it is imported from the training script, which
# would execute that script's top-level code - confirm.
model = Net()
model.load_state_dict(torch.load(PATH))

has the same result.

Is there any way I can load the model without retraining the whole thing?

Timothy Oh
  • 147
  • 5
  • Please edit your question to include the implementation of `Net`. – jkr Nov 21 '22 at 23:03
  • 1
    How exactly do you know that the model is being retrained? – Dr. Snoopy Nov 22 '22 at 00:06
  • because when I run the other file, the code that calls the train function (in the original file that tells me the accuracy and loss for every epoch that is trained) starts running. Basically, running the other file results in the terminal outputting epoch1, epoch2, etc. – Timothy Oh Nov 22 '22 at 00:19
  • 1
    That might be because your script does not have the `if __name__ == "__main__":` guard to separate running it as a script from importing it as a library. – Dr. Snoopy Nov 22 '22 at 00:30
  • 1
    In particular see https://stackoverflow.com/questions/419163/what-does-if-name-main-do and "If you import the guardless script in another script, then the latter script will trigger the former to run at import time and using the second script's command line arguments. This is almost always a mistake." – Dr. Snoopy Nov 22 '22 at 00:36
  • thank you so much. I just copied and pasted the Net class so that the script doesn't run. – Timothy Oh Nov 22 '22 at 00:59

1 Answer

0

I just tried executing the code, and it works perfectly. load_state_dict did not retrain the model:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
    """Small CNN: two conv/max-pool stages followed by two linear layers.

    ``forward`` returns log-probabilities over 10 classes, which is the
    input format expected by ``F.nll_loss``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return per-class log-probabilities, one row of 10 per sample.

        The flatten step hard-codes 320 features per sample
        (20 channels * 4 * 4), which is what a 1x28x28 input produces.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        # Functional dropout has to be told the mode explicitly; the
        # Dropout2d module above follows train()/eval() on its own.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Normalize over the class dimension explicitly: omitting `dim`
        # relies on a deprecated implicit choice and emits a warning.
        return F.log_softmax(x, dim=1)

network = Net()
PATH    = "results/model.pth"
# Restore the saved parameters into the freshly constructed network.
network.load_state_dict(torch.load(PATH))
# works perfectly

By the way, state_dict only contains the model weights and not the dataset, so load_state_dict can never re-train the model.

I think the problem is how the original code is organized. The training procedure starts running immediately after the class Net is defined, so you cannot import Net from this file without re-running everything. Ideally, the training and testing procedures should be wrapped inside an if __name__ == '__main__': block (at the end of the file), so that you can safely import Net without re-running any computation:

# source_file.py
import torch
import torchvision

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
    """Small CNN: two conv/max-pool stages followed by two linear layers.

    ``forward`` returns log-probabilities over 10 classes, which is the
    input format expected by ``F.nll_loss``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return per-class log-probabilities, one row of 10 per sample.

        The flatten step hard-codes 320 features per sample
        (20 channels * 4 * 4), which is what a 1x28x28 input produces.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        # Functional dropout has to be told the mode explicitly; the
        # Dropout2d module above follows train()/eval() on its own.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Normalize over the class dimension explicitly: omitting `dim`
        # relies on a deprecated implicit choice and emits a warning.
        return F.log_softmax(x, dim=1)


def train(epoch):
    """Run one training epoch over ``train_loader``.

    Reads the module-level names ``network``, ``optimizer``,
    ``train_loader``, ``log_interval`` and ``batch_size_train``, which are
    created in the ``__main__`` section of this file. Every ``log_interval``
    batches the current loss is printed and appended to ``train_losses``,
    and the number of samples seen so far is appended to ``train_counter``.

    Args:
        epoch: 1-based index of the epoch being run; used for logging and
            for offsetting the sample counter by the completed epochs.
    """
    network.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = network(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
            # Use the configured batch size rather than a hard-coded 64 so
            # the counter stays correct if batch_size_train is changed.
            train_counter.append(
                (batch_idx * batch_size_train)
                + ((epoch - 1) * len(train_loader.dataset)))

def test():
    """Evaluate ``network`` on ``test_loader`` and report the results.

    Runs in eval mode with gradients disabled, appends the average
    per-sample loss to ``test_losses`` and prints the loss and accuracy.
    Reads the module-level names ``network``, ``test_loader`` and
    ``test_losses``, which are created in the ``__main__`` section.
    """
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            # Sum per batch (not mean) so that dividing by the dataset size
            # below yields the true per-sample average. `reduction='sum'`
            # replaces the deprecated `size_average=False`.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            # .item() keeps `correct` a plain Python int, not a tensor.
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

if __name__ == '__main__':
    # Everything below runs only when this file is executed as a script, so
    # importing Net from it does not start a training run. The names
    # assigned here are module globals, which is how train() and test()
    # find them.
    import os

    n_epochs = 3
    batch_size_train = 64
    batch_size_test = 1000
    learning_rate = 0.01
    momentum = 0.5
    log_interval = 10

    # Seed the global torch RNG before anything that draws from it is built.
    random_seed = 1
    torch.backends.cudnn.enabled = False
    torch.manual_seed(random_seed)

    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            './files', train=True, download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,)),
            ])),
        batch_size=batch_size_train, shuffle=True)

    test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            './files', train=False, download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,)),
            ])),
        batch_size=batch_size_test, shuffle=True)

    # Pull a single batch from the test loader to keep as example data.
    examples = enumerate(test_loader)
    batch_idx, (example_data, example_targets) = next(examples)

    network = Net()
    optimizer = optim.SGD(network.parameters(), lr=learning_rate,
                          momentum=momentum)

    # Histories filled in by train() and test().
    train_losses = []
    train_counter = []
    test_losses = []
    test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]

    # Evaluate the untrained network once, then alternate one training
    # epoch with one evaluation pass.
    test()
    for epoch in range(1, n_epochs + 1):
        train(epoch)
        test()

    PATH = './results/model.pth'
    # Create the output directory first so that the save cannot fail on a
    # missing folder after training has already finished.
    os.makedirs(os.path.dirname(PATH), exist_ok=True)
    torch.save(network.state_dict(), PATH)

Then, when you reload the model in a second file, you can write:

import torch

# The module name matches the `source_file.py` listing above; thanks to its
# __main__ guard, this import does not trigger training.
from source_file import Net

PATH = "results/model.pth"
network = Net()
network.load_state_dict(torch.load(PATH))

Here is a website with more information about if __name__ == '__main__': https://realpython.com/if-name-main-python/

PS. Another option, which I personally use, is to define the neural network in a separate file from the training procedure. This keeps big projects organized and makes it easier to experiment with different neural network designs.


Previous answer:

We should use load_state_dict to restore models:

# Build the model object first, then copy the saved parameters into it.
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
# Switch the model to evaluation mode before using it for inference.
model.eval()

https://pytorch.org/tutorials/beginner/saving_loading_models.html

C-3PO
  • 1,181
  • 9
  • 17
  • This was my initial code, and it still just retrains the whole thing. – Timothy Oh Nov 21 '22 at 22:51
  • 2
    No, `load_state_dict` will not re-train the model. I just use it every day. By the way, to re-train a model, you explicitly have to loop over epochs. Maybe your code is very complex and something else is going on? `load_state_dict` is not guilty, I think. – C-3PO Nov 21 '22 at 23:45
  • 1
    Also, what I cited are the standard PyTorch instructions. – C-3PO Nov 21 '22 at 23:45
  • 1
    Hm. Interesting. I think my Pytorch just hates me :D Thank you – Timothy Oh Nov 22 '22 at 00:02