Saving a model in PyTorch is simple. There are two main methods:
- Save only the parameters (officially recommended);
- Save the entire model (structure + parameters).
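Saving the whole model takes somewhat more storage and, because it pickles the model object itself, ties the saved file to the class definition and directory layout of the code that produced it. The official recommendation is therefore to save only the parameters and load them into a model that has already been built. This article introduces both methods but gives a full example only for the first.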
1. Save parameters only
1) Saving
Generally, parameters can be saved with one statement:
torch.save(model.state_dict(), path)
Here model is an instance of the defined model, such as model = vgg16(), and path is where the parameters are saved, such as path = './model.pth', path = './model.tar', or path = './model.pkl'. PyTorch does not enforce a particular extension, but by convention the file is given one such as these.
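To make the saved object concrete, here is a minimal sketch of what state_dict() contains and how it is written to disk. The nn.Linear model and the temporary directory are hypothetical stand-ins chosen for illustration, not part of the original example.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for a real network such as vgg16().
model = nn.Linear(4, 2)

# state_dict() maps each parameter (and buffer) name to its tensor.
state = model.state_dict()
param_names = list(state.keys())
print(param_names)  # ['weight', 'bias']

# Save only the parameters; the temporary directory is illustrative.
path = os.path.join(tempfile.mkdtemp(), "model.pth")
torch.save(state, path)
print(os.path.getsize(path) > 0)
```

Because only tensors and their names are stored, the file says nothing about how the layers are connected. That is why the model class is needed again at load time.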
In particular, to also save the optimizer state, the epoch, and other information from a training run, combine them into a dictionary and save the dictionary:
state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
torch.save(state, path)
2) Loading
For the first case above, a single statement loads the parameters into the model:
model.load_state_dict(torch.load(path))
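A small round trip may help show what this statement does. The sketch below assumes a hypothetical build_model() helper and a temporary file. map_location="cpu" is an addition so that it also runs on a machine without a GPU.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model():
    # Hypothetical architecture; it must match the one that was saved.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


trained = build_model()
path = os.path.join(tempfile.mkdtemp(), "model.pth")
torch.save(trained.state_dict(), path)

# Rebuild the same architecture, then copy the saved parameters into it.
restored = build_model()
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()  # switch to inference mode before evaluating

x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.allclose(trained(x), restored(x))
print(same_output)
```

The freshly built model starts with random weights. After load_state_dict it should produce the same outputs as the model that was saved.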
For the second case, where a dictionary was saved, load it as follows:
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']
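The dictionary form can be checked end to end with a toy setup. This is a minimal sketch: the nn.Linear model, the SGD settings, the epoch value 3, and the temporary path are all assumptions for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Take one step so the optimizer has state (momentum buffers) worth saving.
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
state = {"model": model.state_dict(),
         "optimizer": optimizer.state_dict(),
         "epoch": 3}
torch.save(state, path)

# Restore into freshly built objects of the same kind.
new_model = nn.Linear(4, 2)
new_optimizer = optim.SGD(new_model.parameters(), lr=0.01, momentum=0.9)

checkpoint = torch.load(path, map_location="cpu")
new_model.load_state_dict(checkpoint["model"])
new_optimizer.load_state_dict(checkpoint["optimizer"])
epoch = checkpoint["epoch"]  # a dictionary lookup with [], not a call
print(epoch)  # 3
```

Restoring the optimizer matters when it keeps per-parameter state such as momentum. Without it, training would resume with that state reset.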
Note that when only the parameters were saved, a model with the same structure as the original must be defined first, and the parameters are loaded into an instance of it (here named model). That is, before the loading statements above run, a class Net identical to the original must already be defined and instantiated with model = Net().
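The requirement that the model definition match can be seen by loading into a deliberately mismatched model. The layer sizes below are hypothetical and only serve to trigger the error.

```python
import torch.nn as nn

saved_state = nn.Linear(4, 2).state_dict()

# Hypothetical mismatched model: its output size differs from the saved one.
wrong_model = nn.Linear(4, 3)

try:
    wrong_model.load_state_dict(saved_state)
    mismatch_detected = False
except RuntimeError as err:
    mismatch_detected = True
    print("load failed:", str(err).splitlines()[0])

# A model with the same definition loads without error.
right_model = nn.Linear(4, 2)
result = right_model.load_state_dict(saved_state)
print(result.missing_keys, result.unexpected_keys)  # [] []
```

load_state_dict checks parameter names and shapes, so a changed architecture is reported immediately rather than silently loading wrong weights.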
The following complete example trains a small network on CIFAR-10 and overwrites a single checkpoint file after every epoch, so only the latest parameters are kept:
import torch
import torchvision as tv
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms

# Parameter declaration
batch_size = 32
epochs = 10
WORKERS = 0  # Number of DataLoader worker processes
test_flag = True  # Test flag. When True, load the saved model and only run the test
ROOT = '/home/pxt/pytorch/cifar'  # CIFAR-10 dataset save path
log_dir = '/home/pxt/pytorch/logs/cifar_model.pth'  # Model save path

# Load the CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=transform)
test_data = tv.datasets.CIFAR10(root=ROOT, train=False, download=False, transform=transform)
train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=WORKERS)
test_load = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=WORKERS)


# Build the model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool(F.relu(self.conv4(x)))
        # Flatten to (batch, channels * height * width)
        x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3])
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = Net().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)


# Model training
def train(model, train_loader, epoch):
    model.train()
    train_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        x, y = data
        x = x.cuda()
        y = y.cuda()
        optimizer.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        # Accumulate a Python float so the computation graph is not kept alive
        train_loss += loss.item()
    loss_mean = train_loss / (i + 1)
    print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean))


# Model test
def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader, 0):
            x, y = data
            x = x.cuda()
            y = y.cuda()
            y_hat = model(x)
            test_loss += criterion(y_hat, y).item()
            pred = y_hat.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
    test_loss /= (i + 1)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_data), 100. * correct / len(test_data)))


def main():
    # If test_flag=True, load the saved model, run the test, and skip training
    if test_flag:
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Do not assign to `epochs` here: that would make it a local variable
        # and break the training loop below when test_flag is False
        last_epoch = checkpoint['epoch']
        test(model, test_load)
        return

    for epoch in range(0, epochs):
        train(model, train_load, epoch)
        test(model, test_load)
        # Save the model, overwriting the previous checkpoint
        state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
        torch.save(state, log_dir)


if __name__ == '__main__':
    main()
3) Continue training based on the loaded model
Training may be interrupted by some problem, or you may want to stop, inspect how training is going, and change parameters such as the learning rate. In both cases the model saved before the interruption has to be loaded so that training can continue from it. Only the main() function in the example above needs to change (plus import os at the top of the script, because the new code calls os.path.exists). The modified main() function is as follows:
import os  # needed for os.path.exists; place with the other imports


def main():
    # If test_flag=True, load the saved model, run the test, and skip training
    if test_flag:
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        test(model, test_load)
        return

    # If there is a saved model, load it and continue training from it
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # The checkpoint stores the last finished epoch, so resume at the next one
        start_epoch = checkpoint['epoch'] + 1
        print('load epoch {} success!'.format(checkpoint['epoch']))
    else:
        start_epoch = 0
        print('No saved model, training will start from scratch!')

    for epoch in range(start_epoch, epochs):
        train(model, train_load, epoch)
        test(model, test_load)
        # Save the model, overwriting the previous checkpoint
        state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
        torch.save(state, log_dir)
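The resume logic can be tried without a GPU or a dataset. The sketch below is a simplified assumption of the same pattern: load_or_init and run are hypothetical helpers, nn.Linear stands in for Net, and the training step is replaced by a comment.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def load_or_init(model, optimizer, path):
    """Return the epoch to start from, restoring state if a checkpoint exists."""
    if os.path.exists(path):
        checkpoint = torch.load(path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        # The saved value is the last finished epoch, so resume at the next one.
        return checkpoint["epoch"] + 1
    return 0


def run(path, total_epochs):
    model = nn.Linear(4, 2)  # hypothetical stand-in for Net()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    start_epoch = load_or_init(model, optimizer, path)
    trained = []
    for epoch in range(start_epoch, total_epochs):
        # train(model, ...) and test(model, ...) would go here.
        trained.append(epoch)
        state = {"model": model.state_dict(),
                 "optimizer": optimizer.state_dict(),
                 "epoch": epoch}
        torch.save(state, path)
    return trained


ckpt = os.path.join(tempfile.mkdtemp(), "resume.pth")
first_run = run(ckpt, total_epochs=2)   # simulates a run interrupted after 2 epochs
second_run = run(ckpt, total_epochs=5)  # picks up at epoch 2
print(first_run, second_run)  # [0, 1] [2, 3, 4]
```

Storing the last finished epoch and resuming at the following one means a fresh run starts at epoch 0 and a resumed run neither repeats nor skips an epoch.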
2. Save the whole model
1) Saving
torch.save(model, path)
2) Loading
model = torch.load(path)
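For completeness, here is a minimal sketch of the whole-model round trip. It uses nn.Sequential as a hypothetical model so that no custom class has to be importable at load time. It also passes weights_only=False, on the assumption that the installed PyTorch accepts this argument and, as in recent releases, would otherwise refuse to unpickle a full model object. Only load files from a trusted source this way, because unpickling can execute arbitrary code.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical model built only from torch.nn classes.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

path = os.path.join(tempfile.mkdtemp(), "whole_model.pth")
torch.save(model, path)  # pickles structure and parameters together

# No model definition is rebuilt here; the object comes back from the file.
loaded = torch.load(path, weights_only=False)
loaded.eval()

x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.allclose(model(x), loaded(x))
print(type(loaded).__name__, same_output)
```

With a custom class such as Net, the class definition must still be importable under the same module path when the file is loaded. This is the coupling that makes the parameters-only method the more portable choice.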