Saving and loading a PyTorch model, and continuing training from the loaded model

Saving a model in PyTorch is very simple. There are two main methods:

  1. Save only the parameters (officially recommended);
  2. Save the entire model (structure + parameters).
    Because saving the whole model consumes much more storage, the official recommendation is to save only the parameters and then load them into a model you have built yourself. This article introduces both methods, but gives a full example only for the first.

1. Save parameters only

1) Saving

In general, the parameters can be saved with a single statement:

torch.save(model.state_dict(), path)
Here model is the model instance you defined, e.g. model = vgg16(), and path is the file path for the parameters, e.g. path = './model.pth', path = './model.tar' or path = './model.pkl'; the file name must include an extension.
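For instance, saving the parameters of a torchvision VGG-16 might look like this (a minimal sketch; the file name is arbitrary):

import torch
import torchvision.models as models

model = models.vgg16()                          # build the model first
torch.save(model.state_dict(), './model.pth')   # save only its parameters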
In particular, if you also want to save the optimizer state, the current epoch and other information from a training run, you can combine them into a dictionary and save the dictionary:
state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
torch.save(state, path)

2) Loading

For the first case above, a single statement loads the parameters into the model:

model.load_state_dict(torch.load(path))
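
If the parameters were saved on a GPU but are loaded on a CPU-only machine, torch.load accepts a map_location argument to remap the tensors:

model.load_state_dict(torch.load(path, map_location='cpu'))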

For the second method above, where a dictionary was saved, the loading code is as follows:

checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']
Note that when loading a parameters-only checkpoint, a model with the same structure as the original must be defined and instantiated beforehand, and the parameters are loaded into that instance. In other words, before running the loading statements above, a Net identical to the original must already be defined and instantiated as model = Net().
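Put together, a minimal sketch of this define-then-load pattern (assuming Net is the original model class and path points to a saved state_dict):

model = Net()                             # same architecture as when the parameters were saved
model.load_state_dict(torch.load(path))   # copy the saved parameters into the instance
model.eval()                              # switch to evaluation mode before inference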
The following is a complete example that saves only the latest parameters:
import os  # needed for os.path.exists when resuming training (section 3)
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision as tv
import torchvision.transforms as transforms

# Parameter declaration
batch_size = 32
epochs = 10
WORKERS = 0   # Number of dataloader worker processes
test_flag = True  # Test flag. When True, load the saved model for testing
ROOT = '/home/pxt/pytorch/cifar'  # CIFAR-10 dataset save path
log_dir = '/home/pxt/pytorch/logs/cifar_model.pth'  # Model save path

# Load the CIFAR-10 dataset
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=transform)
test_data = tv.datasets.CIFAR10(root=ROOT, train=False, download=False, transform=transform)

train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=WORKERS)
test_load = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=WORKERS)


# Build the model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 10)
    
    
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool(F.relu(self.conv4(x)))
        x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3])
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = Net().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)


# model training
def train(model, train_loader, epoch):
    model.train()
    train_loss = 0
    for i, data in enumerate(train_loader, 0):
        x, y = data
        x = x.cuda()
        y = y.cuda()
        optimizer.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()  # .item() detaches the loss so the computation graph is not retained
    loss_mean = train_loss / (i + 1)
    print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean))

# Model test
def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader, 0):
            x, y = data
            x = x.cuda()
            y = y.cuda()
            y_hat = model(x)
            test_loss += criterion(y_hat, y).item()
            pred = y_hat.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
        test_loss /= (i+1)
        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_data), 100. * correct / len(test_data)))


def main():

    # If test_flag=True, load the saved model
    if test_flag:
        # Load the saved model, run validation on the test set directly, and skip the rest of this function
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        epochs = checkpoint['epoch']
        test(model, test_load)
        return

    for epoch in range(0, epochs):
        train(model, train_load, epoch)
        test(model, test_load)
        # Save model
        state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
        torch.save(state, log_dir)

if __name__ == '__main__':
    main()

3) Continue training based on the loaded model

While training a model, the program may be interrupted, or you may need to pause in order to adjust hyperparameters such as the learning rate based on how training is progressing. In that case you have to load the checkpoint saved before the interruption and continue training from it. Only the main() function in the example above needs to change; the modified main() is as follows:
def main():

    # If test_flag=True, load the saved model
    if test_flag:
        # Load the saved model, run validation on the test set directly, and skip the rest of this function
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        test(model, test_load)
        return

    # If there is a saved model, load the model and continue training based on it
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('load epoch {} success!'.format(start_epoch))
    else:
        start_epoch = 0
        print('No saved model, training will start from scratch!')

    for epoch in range(start_epoch+1, epochs):
        train(model, train_load, epoch)
        test(model, test_load)
        # Save model
        state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
        torch.save(state, log_dir)
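
One common reason to resume is to change the learning rate. Since optimizer.load_state_dict restores the old learning rate, you can overwrite it after loading by editing the optimizer's param groups (the value 0.001 below is just a hypothetical new rate):

for param_group in optimizer.param_groups:
    param_group['lr'] = 0.001  # override the learning rate restored from the checkpoint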

  

2. Save the whole model

1) Saving

torch.save(model, path)

2) Loading

model = torch.load(path)
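
Note that this method serializes the model with pickle, so the class definition (Net in the example above) must still be defined or importable when the model is loaded. A minimal sketch, reusing the Net class from the example (the file name is arbitrary):

model = Net()
torch.save(model, './whole_model.pth')    # saves structure + parameters
model2 = torch.load('./whole_model.pth')  # Net must still be defined here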

  
