PyTorch for neural style transfer

To view this article with its pictures, go to http://studyai.com/pytorch-1.4/advanced/neural_style_tutorial.html

This tutorial describes how to implement the neural style algorithm developed by Leon A. Gatys. Neural style, or neural transfer, lets you take an image and reproduce it in a new artistic style. The algorithm takes three images (an input image, a content image and a style image) and modifies the input so that it resembles the content of the content image and the artistic style of the style image.

Underlying principle

The principle is simple: we define two distances, one for content ($D_C$) and one for style ($D_S$). $D_C$ measures how different the content is between two images, while $D_S$ measures how different the style is between two images. Then, we take a third image, the input, and transform it to minimize both its content distance from the content image and its style distance from the style image. Now we can import the necessary packages and begin the neural transfer.

Import packages and select device

The packages listed below are all the packages used to implement the neural transfer.

torch, torch.nn, numpy
torch.optim (efficient gradient descent optimization)
PIL, PIL.Image, matplotlib.pyplot (loading and displaying images)
torchvision.transforms (converting PIL images to tensors)
torchvision.models (training or loading pre-trained models)
copy (deep-copying the model; a system package)
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from PIL import Image
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
import torchvision.models as models

import copy

Next, we need to choose which device to run the network on and import the content and style images. Running the neural transfer algorithm on large images takes longer, and it runs much faster on a GPU. We can use torch.cuda.is_available() to detect whether a GPU is available. Next, we set the torch.device to use throughout the tutorial. The .to(device) method is then used to move a tensor or module to the desired device.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
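As a minimal sketch of the .to(device) call described above, the following moves a tensor and a module to the selected device. The names x and layer are illustrative only and are not part of the tutorial.

```python
import torch
import torch.nn as nn

# Same device selection as in the tutorial
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Illustrative tensor and layer (hypothetical names, not from the tutorial)
x = torch.zeros(1, 3, 8, 8).to(device)             # tensors: .to() returns a copy on the device
layer = nn.Conv2d(3, 4, kernel_size=3).to(device)  # modules: .to() moves the parameters

# Input and weights now live on the same device, so the forward pass works
y = layer(x)
print(x.device.type, tuple(y.shape))
```

Input tensors and model parameters must be on the same device before a forward pass, which is why both the images and the network are moved in this tutorial.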

Loading images

Now we will import the style and content images. The original PIL images have values between 0 and 255, but when converted to torch tensors their values are scaled to lie between 0 and 1. The images also need to be resized to the same dimensions. An important detail to note is that neural networks from the torch library are trained with tensor values ranging from 0 to 1. If you try to feed the network tensor images with values from 0 to 255, the activated feature maps will be unable to sense the intended content and style. However, pre-trained networks from the Caffe library are trained with tensor images in the 0 to 255 range.

Note

Here are the download addresses for the two images used in this tutorial: picasso.jpg and dancing.jpg. Download the two images and put them where the code below loads them from; as written, that is ./data/images/neural-style/ relative to your current working directory.

# Desired size of output image
imsize = 512 if torch.cuda.is_available() else 128  # If you don't have a GPU, make it smaller

loader = transforms.Compose([
    transforms.Resize(imsize),  # Scale imported image
    transforms.ToTensor()])  # Convert it to a torch tensor


def image_loader(image_name):
    image = Image.open(image_name)
    # Fake batch dimension, required to fit the network's input dimensions
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)


style_img = image_loader("./data/images/neural-style/picasso.jpg")
content_img = image_loader("./data/images/neural-style/dancing.jpg")

assert style_img.size() == content_img.size(), \
    "we need to import style and content images of the same size"

Now, let's create a function that displays an image by converting a copy of it to PIL format and showing the copy with plt.imshow. We will display the content and style images to make sure they were imported correctly.

unloader = transforms.ToPILImage()  # Convert to PIL image again

plt.ion()

def imshow(tensor, title=None):
    image = tensor.cpu().clone()  # we clone the tensor to not do changes on it
    image = image.squeeze(0)      # remove the fake batch dimension
    image = unloader(image)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001) # pause a bit so that plots are updated


plt.figure()
imshow(style_img, title='Style Image')

plt.figure()
imshow(content_img, title='Content Image')

Loss functions

Content loss

The content loss is a function that represents a weighted version of the content distance for an individual layer. The function takes the feature maps $F_{XL}$ of a layer $L$ in a network processing input $X$ and returns the content distance between the input image $X$ and the content image $C$. The feature maps of the content image ($F_{CL}$) must be known so that the content distance can be calculated. We implement this function as a torch module with a constructor that takes $F_{CL}$ as input. The distance $\|F_{XL} - F_{CL}\|^2$ is the mean square error between the two sets of feature maps, and can be computed with nn.MSELoss.

We will add this content loss module directly after the convolution layer(s) used to compute the content distance. This way, each time the network is fed an input image, the content loss is computed at the desired layers and, thanks to autograd, all the gradients are computed. Now, to make the content loss layer transparent, we must define a forward method that computes the content loss and then returns the layer's input unchanged. The computed loss is saved as an attribute of the module.

class ContentLoss(nn.Module):

    def __init__(self, target,):
        super(ContentLoss, self).__init__()
        # we 'detach' the target content from the tree used
        # to dynamically compute the gradient: this is a stated value,
        # not a variable. Otherwise the forward method of the criterion
        # will throw an error.
        self.target = target.detach()

    def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input

Note

Important detail: although this module is named ContentLoss, it is not a true PyTorch loss function. If you want to define your content loss as a PyTorch loss function, you have to create a PyTorch autograd function and compute/implement the gradient manually in its backward method.
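To illustrate what "transparent" means here, the sketch below feeds a small tensor through a copy of the ContentLoss module. The all-zeros target and all-ones input are made-up values chosen so the loss is easy to verify by hand.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ContentLoss(nn.Module):
    """Copy of the tutorial's module, repeated so this example is self-contained."""

    def __init__(self, target):
        super().__init__()
        self.target = target.detach()

    def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input


# Made-up feature maps: target is all zeros, input is all ones
target = torch.zeros(1, 2, 3, 3)
x = torch.ones(1, 2, 3, 3, requires_grad=True)

layer = ContentLoss(target)
out = layer(x)            # the layer passes its input through unchanged
layer.loss.backward()     # autograd still provides gradients w.r.t. the input

print(out is x)           # the very same tensor object comes back
print(layer.loss.item())  # mean of (1 - 0)^2 over all elements
print(x.grad[0, 0, 0, 0].item())
```

Because the module returns its input untouched, it can be placed anywhere inside an nn.Sequential without altering what the following layers see.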

Style loss

The style loss module is implemented similarly to the content loss module. It acts as a transparent layer in the network that computes the style loss of that layer. To calculate the style loss, we need to compute the Gram matrix $G_{XL}$. A Gram matrix is the result of multiplying a given matrix by its transpose. In this application, the given matrix is a reshaped version of the feature maps $F_{XL}$ of layer $L$. $F_{XL}$ is reshaped to form $\hat{F}_{XL}$, a $K \times N$ matrix, where $K$ is the number of feature maps at layer $L$ and $N$ is the length of any vectorized feature map $F_{XL}^k$. For example, the first row of $\hat{F}_{XL}$ corresponds to the first vectorized feature map $F_{XL}^1$.

Finally, the Gram matrix must be normalized by dividing each element by the total number of elements in the matrix. This normalization counteracts the fact that $\hat{F}_{XL}$ matrices with a large $N$ dimension yield larger values in the Gram matrix. These larger values would cause the first layers (the layers before the pooling layers) to have a larger impact during gradient descent. Style features tend to be found in the deeper layers of the network, so this normalization step is crucial.
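Written out, and assuming a batch size of 1 as in this tutorial's gram_matrix function (which divides by a * b * c * d), the normalized Gram matrix described above can be sketched as:

```latex
\hat{F}_{XL} \in \mathbb{R}^{K \times N}, \qquad
G_{XL} = \frac{1}{K N}\, \hat{F}_{XL}\, \hat{F}_{XL}^{\top}, \qquad
\left(G_{XL}\right)_{ij} = \frac{1}{K N} \sum_{n=1}^{N}
    \left(\hat{F}_{XL}\right)_{in} \left(\hat{F}_{XL}\right)_{jn}
```

Each entry $(G_{XL})_{ij}$ is therefore a normalized inner product between the vectorized feature maps $i$ and $j$, and $G_{XL}$ is a symmetric $K \times K$ matrix.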

def gram_matrix(input):
    a, b, c, d = input.size()  # a=batch size(=1)
    # b=number of feature maps
    # (c,d)=dimensions of a f. map (N=c*d)

    features = input.view(a * b, c * d)  # resize F_XL into \hat F_XL

    G = torch.mm(features, features.t())  # compute the gram product

    # we 'normalize' the values of the gram matrix
    # by dividing by the number of elements in each feature map.
    return G.div(a * b * c * d)
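A small worked example may help check the function. The feature values below are made up: two feature maps of size 1x2, so that K = 2 and N = 2 and the result can be verified by hand.

```python
import torch


def gram_matrix(input):
    """Copy of the tutorial's function, repeated so this example is self-contained."""
    a, b, c, d = input.size()
    features = input.view(a * b, c * d)
    G = torch.mm(features, features.t())
    return G.div(a * b * c * d)


# Made-up input of shape [1 x 2 x 1 x 2]:
# feature map 1 is [1, 2], feature map 2 is [3, 4]
fmap = torch.tensor([[[[1.0, 2.0]], [[3.0, 4.0]]]])

G = gram_matrix(fmap)

# Unnormalized products: [[1*1+2*2, 1*3+2*4], [3*1+4*2, 3*3+4*4]] = [[5, 11], [11, 25]]
# Dividing by a*b*c*d = 4 gives [[1.25, 2.75], [2.75, 6.25]]
print(G)
```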

Now the style loss module looks almost exactly like the content loss module. The style distance is computed as the mean square error between $G_{XL}$ and $G_{SL}$.

class StyleLoss(nn.Module):

    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        self.target = gram_matrix(target_feature).detach()

    def forward(self, input):
        G = gram_matrix(input)
        self.loss = F.mse_loss(G, self.target)
        return input
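One property worth seeing in code: because the Gram matrix sums over spatial positions, it does not depend on where in the image a feature occurs. The sketch below, using random made-up feature maps, shuffles the spatial positions of the style features and shows that the style loss stays (numerically) zero, while unrelated features give a larger loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def gram_matrix(input):
    a, b, c, d = input.size()
    features = input.view(a * b, c * d)
    G = torch.mm(features, features.t())
    return G.div(a * b * c * d)


class StyleLoss(nn.Module):
    """Copy of the tutorial's module, repeated so this example is self-contained."""

    def __init__(self, target_feature):
        super().__init__()
        self.target = gram_matrix(target_feature).detach()

    def forward(self, input):
        G = gram_matrix(input)
        self.loss = F.mse_loss(G, self.target)
        return input


torch.manual_seed(0)
style_features = torch.rand(1, 4, 8, 8)  # made-up "style" feature maps

# Apply the same random permutation of spatial positions to every feature map
perm = torch.randperm(8 * 8)
shuffled = style_features.reshape(1, 4, -1)[:, :, perm].reshape(1, 4, 8, 8)

layer = StyleLoss(style_features)

layer(shuffled)
shuffled_loss = layer.loss.item()   # same Gram matrix up to rounding

layer(torch.rand(1, 4, 8, 8))
other_loss = layer.loss.item()      # unrelated features: different Gram matrix

print(shuffled_loss, other_loss)
```

This is one way to see why the Gram matrix captures texture-like statistics rather than the spatial layout of the image.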

Import model

Now we need to import a pre-trained neural network. We will use a 19-layer VGG network, like the one used in the paper.

PyTorch's implementation of VGG is a module divided into two child Sequential modules: features (containing the convolution and pooling layers) and classifier (containing the fully connected layers). We will use the features module because we need the output of the individual convolution layers to measure the content and style losses. Some layers behave differently during training than during evaluation, so we must set the network to evaluation mode using .eval().

cnn = models.vgg19(pretrained=True).features.to(device).eval()
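As a generic illustration of a layer that behaves differently in training and evaluation mode, the sketch below uses nn.Dropout. This layer is chosen only because its behavior is easy to observe; the example makes no claim about which layers vgg19.features contains.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1, 1000)

drop.train()            # training mode: elements are randomly zeroed, the rest rescaled
train_out = drop(x)

drop.eval()             # evaluation mode: the layer is the identity
eval_out = drop(x)

print((train_out == 0).sum().item(), "elements zeroed in training mode")
print(torch.equal(eval_out, x))
```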

In addition, the VGG network is trained on images normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225] for each channel. We will use them to normalize the image, and then send the normalized image to the network for processing.

cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

# Create a module to normalize the input image so that we can easily put it in an nn.Sequential.
class Normalization(nn.Module):
    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        # .view the mean and std to make them [C x 1 x 1] so that they can
        # directly work with image Tensor of shape [B x C x H x W].
        # B is batch size. C is number of channels. H is height and W is width.
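        # mean and std are passed in as tensors; clone().detach() copies them
        # without the warning that torch.tensor(tensor) raises.
        self.mean = mean.clone().detach().view(-1, 1, 1)
        self.std = std.clone().detach().view(-1, 1, 1)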

    def forward(self, img):
        # normalize img
        return (img - self.mean) / self.std
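The reshaping to [C x 1 x 1] matters because of broadcasting. As a minimal sketch with made-up image values, the per-channel statistics below are broadcast over a [B x C x H x W] batch:

```python
import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)  # shape [3 x 1 x 1]
std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)

# Made-up batch: one 2x2 RGB image with every value equal to 0.5
img = torch.full((1, 3, 2, 2), 0.5)

# [3 x 1 x 1] broadcasts against [1 x 3 x 2 x 2]: each channel uses its own mean/std
normalized = (img - mean) / std
print(tuple(normalized.shape))
print(normalized[0, :, 0, 0])

# An image whose channels equal the channel means is mapped to all zeros
mean_img = mean.expand(1, 3, 2, 2)
zeros = (mean_img - mean) / std
print(zeros.abs().max().item())
```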

A Sequential module contains an ordered list of child modules. For example, vgg19.features contains a sequence (Conv2d, ReLU, MaxPool2d, Conv2d, ReLU, ...) aligned in order of depth. We need to add our content loss and style loss layers immediately after the convolution layers they are detecting. To do this, we must create a new Sequential module that has the content loss and style loss modules correctly inserted.

# Desired depth layers at which to compute the style/content losses:
content_layers_default = ['conv_4']
style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']

def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                               style_img, content_img,
                               content_layers=content_layers_default,
                               style_layers=style_layers_default):
    cnn = copy.deepcopy(cnn)

    # normalization module
    normalization = Normalization(normalization_mean, normalization_std).to(device)

    # just in order to have iterable access to the lists of content/style
    # losses
    content_losses = []
    style_losses = []

    # assuming that cnn is a nn.Sequential, so we make a new nn.Sequential
    # to put in modules that are supposed to be activated sequentially
    model = nn.Sequential(normalization)

    i = 0  # increment every time we see a conv
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = 'conv_{}'.format(i)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            # The in-place version doesn't play very nicely with the ContentLoss
            # and StyleLoss we insert below. So we replace with out-of-place
            # ones here.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        model.add_module(name, layer)

        if name in content_layers:
            # add content loss:
            target = model(content_img).detach()
            content_loss = ContentLoss(target)
            model.add_module("content_loss_{}".format(i), content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            # add style loss:
            target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module("style_loss_{}".format(i), style_loss)
            style_losses.append(style_loss)

    # now we trim off the layers after the last content and style losses
    for i in range(len(model) - 1, -1, -1):
        if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
            break

    model = model[:(i + 1)]

    return model, style_losses, content_losses
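The naming and trimming logic can be tried without downloading VGG weights. The sketch below applies the same loop to a tiny made-up network (the layer sizes are arbitrary assumptions) and omits the loss modules, so it only shows how layers are renamed and how slicing an nn.Sequential trims it.

```python
import torch.nn as nn

# Tiny stand-in for vgg19.features (arbitrary sizes, illustration only)
toy = nn.Sequential(
    nn.Conv2d(3, 4, 3), nn.ReLU(inplace=True), nn.MaxPool2d(2),
    nn.Conv2d(4, 4, 3), nn.ReLU(inplace=True))

named = nn.Sequential()
i = 0  # incremented every time we see a conv
for layer in toy.children():
    if isinstance(layer, nn.Conv2d):
        i += 1
        name = 'conv_{}'.format(i)
    elif isinstance(layer, nn.ReLU):
        name = 'relu_{}'.format(i)
        layer = nn.ReLU(inplace=False)  # replace the in-place version
    elif isinstance(layer, nn.MaxPool2d):
        name = 'pool_{}'.format(i)
    else:
        raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
    named.add_module(name, layer)

names = [n for n, _ in named.named_children()]
print(names)

# Slicing an nn.Sequential returns a new nn.Sequential with the leading layers
trimmed = named[:4]
print(len(trimmed))
```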

Next, we select the input image. You can use a copy of the content image or a white noise image as the input image.

input_img = content_img.clone()
# If you want to use white noise, remove the comment from this line of code:
# input_img = torch.randn(content_img.data.size(), device=device)

# Add the original input image to the figure:
plt.figure()
imshow(input_img, title='Input Image')

Gradient descent

The author of this algorithm suggests using the L-BFGS algorithm to run the gradient descent. Unlike training a network, here we want to train the input image in order to minimize the content/style losses. We will create a PyTorch L-BFGS optimizer, optim.LBFGS, and pass our image to it as the tensor to optimize.

def get_input_optimizer(input_img):
    # this line to show that input is a parameter that requires a gradient
    optimizer = optim.LBFGS([input_img.requires_grad_()])
    return optimizer

Finally, we must define a function that performs the neural transfer. For each iteration of the network, it is fed an updated input and computes new losses. We call backward on the combined loss to dynamically compute the gradients. The optimizer requires a "closure" function, which re-evaluates the model and returns the loss.
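The closure pattern can be seen in isolation on a toy problem. This is only a sketch: instead of an image, it optimizes a single made-up scalar x to minimize (x - 3)^2, which is an assumption chosen to keep the example small.

```python
import torch
import torch.optim as optim

# Toy "input" to optimize: a single scalar that requires a gradient
x = torch.zeros(1, requires_grad=True)
optimizer = optim.LBFGS([x])


def closure():
    # L-BFGS may call this several times per step, so it must
    # clear the gradients, recompute the loss and call backward each time
    optimizer.zero_grad()
    loss = ((x - 3.0) ** 2).sum()
    loss.backward()
    return loss


for _ in range(5):
    optimizer.step(closure)

print(x.item())
```

Unlike most optimizers in torch.optim, optim.LBFGS needs to re-evaluate the function several times within one step, which is why step() takes the closure as an argument.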

We still have one final constraint to address. The network may try to optimize the input with values that exceed the 0 to 1 tensor range of the image. We can address this by clamping the input values to between 0 and 1 every time the network is run.

def run_style_transfer(cnn, normalization_mean, normalization_std,
                       content_img, style_img, input_img, num_steps=300,
                       style_weight=1000000, content_weight=1):
    """Run the style transfer."""
    print('Building the style transfer model..')
    model, style_losses, content_losses = get_style_model_and_losses(cnn,
        normalization_mean, normalization_std, style_img, content_img)
    optimizer = get_input_optimizer(input_img)

    print('Optimizing..')
    run = [0]
    while run[0] <= num_steps:

        def closure():
            # correct the values of updated input image
            input_img.data.clamp_(0, 1)

            optimizer.zero_grad()
            model(input_img)
            style_score = 0
            content_score = 0

            for sl in style_losses:
                style_score += sl.loss
            for cl in content_losses:
                content_score += cl.loss

            style_score *= style_weight
            content_score *= content_weight

            loss = style_score + content_score
            loss.backward()

            run[0] += 1
            if run[0] % 50 == 0:
                print("run {}:".format(run))
                print('Style Loss : {:4f} Content Loss: {:4f}'.format(
                    style_score.item(), content_score.item()))
                print()

            return style_score + content_score

        optimizer.step(closure)

    # a last correction...
    input_img.data.clamp_(0, 1)

    return input_img

Finally, we can run the algorithm.

output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
                            content_img, style_img, input_img)

plt.figure()
imshow(output, title='Output Image')

plt.ioff()
plt.show()

Tags: network

Posted on Wed, 11 Mar 2020 06:10:37 -0400 by johnnyblaze1980