Saving and Loading Models

This notebook demonstrates how to save and load models with PyTorch. This matters because you will often want to load a previously trained model, either to make predictions or to continue training on new data.

%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms
# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
# Download and load the training data
trainset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', 
                                 download=True, 
                                 train=True, 
                                 transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, 
                                          batch_size=64, 
                                          shuffle=True)

# Download and load the test data
testset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', 
                                download=True, 
                                train=False, 
                                transform=transform)
testloader = torch.utils.data.DataLoader(testset, 
                                         batch_size=64, 
                                         shuffle=True)

Helper function to show images.

def imshow(image, ax=None, title=None, normalize=True):
    if ax is None:
        fig, ax = plt.subplots()
    image = image.numpy().transpose((1, 2, 0))

    if normalize:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        image = std * image + mean
        image = np.clip(image, 0, 1)

    ax.imshow(image)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.tick_params(axis='both', length=0)
    ax.set_xticklabels('')
    ax.set_yticklabels('')

    return ax

Here we can see one of the images.

image, label = next(iter(trainloader))
imshow(image[0,:]);

[image: a sample image from the Fashion-MNIST training set]

Train a network

To make things more concise here, I moved the model architecture and training code from the last part into a file called fc_model; its contents are reproduced below. Importing that file, we can easily create a fully-connected network with fc_model.Network and train it with fc_model.train. I'll use this model (once it's trained) to demonstrate how we can save and load models.

import torch
from torch import nn
import torch.nn.functional as F

class Network(nn.Module):
    def __init__(self, input_size, output_size, hidden_layers, drop_p=0.5):
        super().__init__()
        
        # Input to a hidden layer
        self.hidden_layers = nn.ModuleList([nn.Linear(input_size, hidden_layers[0])])
        
        # Add a variable number of more hidden layers
        layer_sizes = zip(hidden_layers[:-1], hidden_layers[1:])
        self.hidden_layers.extend([nn.Linear(h1, h2) for h1, h2 in layer_sizes])
        
        self.output = nn.Linear(hidden_layers[-1], output_size)
        
        self.dropout = nn.Dropout(p=drop_p)
        
    def forward(self, x):
        for each in self.hidden_layers:
            x = F.relu(each(x))
            x = self.dropout(x)
        x = self.output(x)
        
        return F.log_softmax(x, dim=1)
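
As a quick sanity check (a sketch, not part of the original notebook), we can pass a fake batch through the network and confirm that it maps 784-long vectors to log-probabilities over the 10 classes:

net = Network(784, 10, [512, 256, 128])
x = torch.randn(64, 784)                  # a fake batch of 64 flattened images
log_ps = net(x)
print(log_ps.shape)                       # torch.Size([64, 10])
print(torch.exp(log_ps).sum(dim=1)[:3])  # each row sums to ~1, as probabilities should
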
def validation(model, testloader, criterion):
    accuracy, test_loss = 0, 0
    for images, labels in testloader:

        images = images.view(images.shape[0], -1)

        output = model(images)
        test_loss += criterion(output, labels).item()

        ## Calculating the accuracy 
        # Model's output is log-softmax, take exponential to get the probabilities
        ps = torch.exp(output)
        
        # Class with highest probability is our predicted class, compare with true label
        equality = (labels.data == ps.max(1)[1])
        
        # Accuracy is number of correct predictions divided by all predictions, 
        #      so just take the mean
        accuracy += equality.type_as(torch.FloatTensor()).mean()

    return test_loss, accuracy
def train(model, trainloader, testloader, criterion, optimizer, epochs=5, print_every=40):
    
    running_loss = 0
    print('Epoch\tStep\tTraining Loss\tTest Loss\tTest Accuracy')
    for e in range(epochs):
        # Reset Step Counter
        step = 0
        
        # Model in training mode, dropout is on
        model.train()
        
        for images, labels in trainloader:
            
            # Flatten images into a 784-long vector
            images = images.view(images.shape[0], -1)
            
            optimizer.zero_grad()
            
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()

            if step % print_every == 0:
                # Model in inference mode, dropout is off
                model.eval()
                
                # Turn off gradients for validation, will speed up inference
                with torch.no_grad():
                    test_loss, accuracy = validation(model, testloader, criterion)
                
                print("{:3}/{}\t{:4}\t{:13.3f}\t{:9.3f}\t{:13.3f}"
                      .format(e+1, epochs, 
                              step, 
                              running_loss/len(trainloader),
                              test_loss/len(testloader), 
                              accuracy/len(testloader)))

                running_loss = 0
                
                # Make sure dropout and grads are on for training
                model.train()
            
            step += 1

Create the network, define the criterion and optimizer, and train

model = Network(784, 10, [512, 256, 128])
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
train(model, trainloader, testloader, criterion, optimizer, epochs=3, print_every=392)
Epoch	Step	Training Loss	Test Loss	Test Accuracy
  1/3	   0	        0.002	    2.282	        0.178
  1/3	 392	        0.362	    0.540	        0.799
  1/3	 784	        0.252	    0.495	        0.823
  2/3	   0	        0.091	    0.494	        0.811
  2/3	 392	        0.229	    0.468	        0.830
  2/3	 784	        0.220	    0.458	        0.837
  3/3	   0	        0.085	    0.447	        0.842
  3/3	 392	        0.209	    0.427	        0.842
  3/3	 784	        0.205	    0.417	        0.848

Saving and loading networks

As you can imagine, it's impractical to train a network from scratch every time you need to use it. Instead, we can save a trained network and load it later, either to continue training or to make predictions.

The parameters for PyTorch networks are stored in a model’s state_dict. We can see the state dict contains the weight and bias matrices for each of our layers.

print("Our model: \n\n", model, '\n')
print("The state dict keys:\n")

for key in model.state_dict().keys():
    print(key)
Our model: 

 Network(
  (hidden_layers): ModuleList(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): Linear(in_features=512, out_features=256, bias=True)
    (2): Linear(in_features=256, out_features=128, bias=True)
  )
  (output): Linear(in_features=128, out_features=10, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
) 

The state dict keys:

hidden_layers.0.weight
hidden_layers.0.bias
hidden_layers.1.weight
hidden_layers.1.bias
hidden_layers.2.weight
hidden_layers.2.bias
output.weight
output.bias
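
Each value in the state dict is just a tensor of parameters. Checking the shapes (a quick sketch) confirms they match the layer sizes above:

for key, tensor in model.state_dict().items():
    print(key, tuple(tensor.shape))
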

The simplest thing to do is save the state dict with torch.save. For example, we can save it to the file 'checkpoint.pth'.

torch.save(model.state_dict(), 
           'saving-models/checkpoint.pth')

Note that the file is relatively large at 2.3 MB.

%%bash
cd saving-models
ls -l
total 2224
-rw-r--r-- 1 ryan ryan 2271185 May 29 06:51 checkpoint.pth
-rw-r--r-- 1 ryan ryan    2603 May 28 06:07 saving-models_7_0.png
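
That size is consistent with the architecture: 784·512 + 512·256 + 256·128 + 128·10 weights plus 906 biases come to 567,434 parameters, and at 4 bytes per float32 value that is about 2.27 MB; the remainder is serialization overhead. A quick sketch to verify:

n_params = sum(p.numel() for p in model.state_dict().values())
print(n_params, '->', n_params * 4, 'bytes')  # 567434 -> 2269736 bytes
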

Then we can load the state dict with torch.load.

state_dict = torch.load('saving-models/checkpoint.pth')

for key in state_dict.keys():
    print(key)
hidden_layers.0.weight
hidden_layers.0.bias
hidden_layers.1.weight
hidden_layers.1.bias
hidden_layers.2.weight
hidden_layers.2.bias
output.weight
output.bias

And to load the state dict into the network, you do model.load_state_dict(state_dict).

model.load_state_dict(state_dict)
<All keys matched successfully>

Seems pretty straightforward, but as usual it’s a bit more complicated. Loading the state dict works only if the model architecture is exactly the same as the checkpoint architecture. If you create a model with a different architecture, this fails:

# Try this
model = Network(784, 10, [400, 200, 100])
# This will throw an error because the tensor sizes are wrong!
model.load_state_dict(state_dict)
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-14-796f2ae97f9b> in <module>
      2 model = Network(784, 10, [400, 200, 100])
      3 # This will throw an error because the tensor sizes are wrong!
----> 4 model.load_state_dict(state_dict)


~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    845         if len(error_msgs) > 0:
    846             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 847                                self.__class__.__name__, "\n\t".join(error_msgs)))
    848         return _IncompatibleKeys(missing_keys, unexpected_keys)
    849 


RuntimeError: Error(s) in loading state_dict for Network:
	size mismatch for hidden_layers.0.weight: copying a param with shape torch.Size([512, 784]) from checkpoint, the shape in current model is torch.Size([400, 784]).
	size mismatch for hidden_layers.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([400]).
	size mismatch for hidden_layers.1.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([200, 400]).
	size mismatch for hidden_layers.1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([200]).
	size mismatch for hidden_layers.2.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([100, 200]).
	size mismatch for hidden_layers.2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([100]).
	size mismatch for output.weight: copying a param with shape torch.Size([10, 128]) from checkpoint, the shape in current model is torch.Size([10, 100]).

This means we need to rebuild the model exactly as it was when trained, so information about the model architecture needs to be saved in the checkpoint along with the state dict. To do this, you build a dictionary with all the information you need to completely rebuild the model.

checkpoint = {'input_size': 784,
              'output_size': 10,
              'hidden_layers': [each.out_features for each in model.hidden_layers],
              'state_dict': model.state_dict()}

torch.save(checkpoint, 'saving-models/checkpoint.pth')

Now the checkpoint has all the necessary information to rebuild the trained model. You can easily make that a function if you want. Similarly, we can write a function to load checkpoints.
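
For instance, a save helper might look like this (a minimal sketch; save_checkpoint and the in_features/out_features lookups are my additions, not from the original notebook):

def save_checkpoint(model, filepath):
    # Store the architecture alongside the weights so the model can be
    # rebuilt before its state dict is loaded
    checkpoint = {'input_size': model.hidden_layers[0].in_features,
                  'output_size': model.output.out_features,
                  'hidden_layers': [each.out_features for each in model.hidden_layers],
                  'state_dict': model.state_dict()}
    torch.save(checkpoint, filepath)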

def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    model = Network(checkpoint['input_size'],
                    checkpoint['output_size'],
                    checkpoint['hidden_layers'])
    model.load_state_dict(checkpoint['state_dict'])
    
    return model
model = load_checkpoint('saving-models/checkpoint.pth')
print(model)
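
As a final check (a quick sketch), we can run the reloaded model through the validation pass from earlier, using the testloader and criterion defined above:

model.eval()
with torch.no_grad():
    test_loss, accuracy = validation(model, testloader, criterion)
print('Test accuracy: {:.3f}'.format(accuracy / len(testloader)))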