PyTorch deep learning practical course: UNet semantic segmentation network - Zhihu

The code has been included in GitHub Jack-Cherish/PythonPark, which also contains technical articles, curated learning materials, and first-hand big-company interview experience. Welcome to Star and help improve it.

1, Foreword

This article belongs to the PyTorch deep learning semantic segmentation tutorial series.

The contents of this series are:

  • Basic use of PyTorch
  • Explanation of semantic segmentation algorithms

If you are not familiar with the principles of semantic segmentation or how to set up the development environment, please see the previous article in this series: <PyTorch deep learning practical course (I): semantic segmentation basics and environment setup>.

This article uses the Windows development environment set up in the previous article:

Development environment: Windows

Development language: Python 3.7.4

Framework version: PyTorch 1.3.0

CUDA: 10.2

cuDNN: 7.6.0

This article mainly explains the network structure of UNet and how to write the corresponding code.

PS: all the code in this article can be downloaded from my GitHub. Welcome to Follow and Star: Click to view

2, UNet network structure

In the field of semantic segmentation, the pioneering deep-learning-based work is FCN (Fully Convolutional Networks for Semantic Segmentation). UNet follows FCN's principles and makes corresponding improvements to suit simple segmentation tasks with small training sets.

UNet paper address: Click to view

To study a deep learning algorithm, start with the network structure; once it is understood, move on to the loss calculation, training method, and so on. This article focuses on UNet's network structure; the other topics will be explained in later chapters.

1. Network structure principle

UNet was first published at the MICCAI conference in 2015. In just over four years, the paper has been cited more than 9,700 times.

UNet has become the baseline for most medical image semantic segmentation tasks and has inspired a large number of researchers to study U-shaped network structures, leading to many papers that improve on the UNet architecture.

The two main characteristics of the UNet architecture are the U-shaped network structure and the skip connections.



UNet is a symmetrical network: the left side performs downsampling, the right side upsampling.

Functionally, the series of downsampling operations on the left can be called the encoder, and the series of upsampling operations on the right the decoder.

The four gray parallel lines in the middle are the skip connections: they carry feature maps from the downsampling path over to the upsampling path.

The fusion operation used in the skip connections is very simple: stack the channels of the feature maps, commonly known as Concat.

The Concat operation is easy to understand. For example, take book A, 10cm*10cm in size and 3cm thick, and book B, 10cm*10cm in size and 4cm thick.

Stack books A and B with their edges aligned. This gives a stack of books 10cm*10cm in size and 7cm thick, similar to this:



This kind of "stacking" operation is Concat.

The same goes for feature maps: a feature map of size 256 * 256 * 64 (w (width) = 256, h (height) = 256, c (number of channels) = 64), concatenated with a feature map of size 256 * 256 * 32, yields a feature map of size 256 * 256 * 96.
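This channel stacking is a one-liner in PyTorch. The following is a small sketch (not from the original article; tensor contents are random and only the shapes matter) showing torch.cat along the channel dimension of an NCHW tensor:

```python
import torch

# Two feature maps with the same spatial size, in NCHW layout:
a = torch.randn(1, 64, 256, 256)   # 256 x 256, 64 channels
b = torch.randn(1, 32, 256, 256)   # 256 x 256, 32 channels

# Concat stacks the maps along the channel dimension (dim=1)
fused = torch.cat([a, b], dim=1)
print(fused.shape)  # torch.Size([1, 96, 256, 256])
```

The spatial dimensions must match exactly; only the channel counts add up.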

In practice, the two feature maps fused by Concat do not necessarily have the same spatial size; for example, a 256 * 256 * 64 feature map may need to be concatenated with a 240 * 240 * 32 one.

At this time, there are two ways:

First, crop the 256 * 256 * 64 feature map down to 240 * 240 * 64, for example by discarding 8 pixels from each of the top, bottom, left, and right edges. After cropping, Concat yields a 240 * 240 * 96 feature map.

Second, pad the 240 * 240 * 32 feature map up to 256 * 256 * 32, for example by adding 8 pixels to each of the top, bottom, left, and right edges, and then Concat to get a 256 * 256 * 96 feature map.

The Concat scheme adopted by the UNet implementation in this article is the second one: the smaller feature map is padded, and the padding fills in zeros, i.e. ordinary constant padding.
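The pad-then-concat step can be sketched as follows (this example is not from the original article; it mirrors the logic used later in the Up module, with made-up tensor sizes). F.pad takes the padding as (left, right, top, bottom), which also handles odd size differences:

```python
import torch
import torch.nn.functional as F

small = torch.randn(1, 32, 240, 240)
large = torch.randn(1, 64, 256, 256)

# Size differences along height and width
diffY = large.size(2) - small.size(2)   # 16
diffX = large.size(3) - small.size(3)   # 16

# Zero-pad the smaller map: (left, right, top, bottom)
small = F.pad(small, [diffX // 2, diffX - diffX // 2,
                      diffY // 2, diffY - diffY // 2])

fused = torch.cat([large, small], dim=1)
print(fused.shape)  # torch.Size([1, 96, 256, 256])
```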

2. Code

Some readers may not know much about PyTorch, so here is a recommended official quick-start tutorial: in about an hour you can pick up the basic concepts and how PyTorch code is written.

Official PyTorch quick-start tutorial: Click to view

We will divide the whole UNet network into several modules for explanation.

DoubleConv module:

Let's first look at two consecutive convolution operations.



As the UNet diagram shows, every layer in both the downsampling and upsampling paths performs two consecutive convolution operations. This pattern repeats many times throughout the network, so it is worth writing a separate DoubleConv module:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

A note on the PyTorch code above: torch.nn.Sequential is a sequential container; modules are added to it in the order they are passed in, and the input is run through them in that same order. For the code above: convolution -> BN -> ReLU -> convolution -> BN -> ReLU.
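To make the ordering concrete, here is a tiny standalone sketch (not from the original article; the channel counts and input size are made up) of a single convolution -> BN -> ReLU stage built with nn.Sequential:

```python
import torch
import torch.nn as nn

# Modules run in exactly the order they are passed to nn.Sequential:
block = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=0),  # 1: 3x3 conv, no padding
    nn.BatchNorm2d(8),                          # 2: batch norm over 8 channels
    nn.ReLU(inplace=True),                      # 3: non-linearity
)

x = torch.randn(1, 1, 32, 32)
y = block(x)          # 32 -> 30 spatially (valid conv), 1 -> 8 channels
print(y.shape)        # torch.Size([1, 8, 30, 30])
```

After the ReLU, every value in the output is non-negative.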

The in_channels and out_channels of the DoubleConv module can be set flexibly for reuse.

In the network diagram above, in_channels is set to 1 and out_channels to 64.

The input image size is 572 * 572. After a 3 * 3 convolution with stride 1 and padding 0, a 570 * 570 feature map is obtained; after another such convolution, a 568 * 568 feature map.

Calculation formula: O=(H − F+2 × P)/S+1

H is the size of the input feature map, O is the size of the output feature map, F is the size of the convolution kernel, P is the size of padding, and S is the step size.
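The formula is easy to check in code. This small helper (not from the original article; the function name conv_out_size is my own) reproduces the 572 -> 570 -> 568 shrinkage of the first two convolutions:

```python
def conv_out_size(h, f, p, s):
    """O = (H - F + 2*P) / S + 1 for one spatial dimension."""
    return (h - f + 2 * p) // s + 1

# UNet's first two 3x3 convolutions (stride 1, padding 0):
print(conv_out_size(572, 3, 0, 1))  # 570
print(conv_out_size(570, 3, 0, 1))  # 568

# For comparison, padding=1 would preserve the size:
print(conv_out_size(572, 3, 1, 1))  # 572
```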

Down module:



There are four downsampling steps in the UNet network; the modularized code is as follows:

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

The code here is very simple: a max-pooling layer for downsampling, followed by a DoubleConv module.
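The shape arithmetic of one Down step can be sketched inline without the module classes (this example is not from the original article; the 68 * 68 input size is made up to keep it cheap, and BN is omitted for brevity): 2 * 2 max pooling halves H and W, then each valid 3 * 3 convolution shaves off 2 pixels.

```python
import torch
import torch.nn as nn

pool = nn.MaxPool2d(2)
convs = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=3, padding=0),
    nn.ReLU(inplace=True),
    nn.Conv2d(128, 128, kernel_size=3, padding=0),
    nn.ReLU(inplace=True),
)

x = torch.randn(1, 64, 68, 68)
y = convs(pool(x))        # 68 -> 34 (pool) -> 32 -> 30 (two valid convs)
print(y.shape)            # torch.Size([1, 128, 30, 30])
```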

So far, the code for the downsampling process in the left half of the UNet network is done; next comes the upsampling process in the right half.

Up module:

The upsampling path is, of course, mostly about upsampling, but besides the upsampling operation itself it also performs feature fusion.



The implementation of this code is a little more complicated:

class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            # bilinear upsampling keeps the channel count, so after concat
            # with the skip connection (in_channels // 2 channels in this
            # network) the DoubleConv sees in_channels + in_channels // 2
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels + in_channels // 2, out_channels)
        else:
            # the transposed convolution halves the channel count while upsampling
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW; pad x1 so its spatial size matches x2
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

The code is a little more complicated, so let's look at it in pieces. First, the __init__ initialization function defines the upsampling method and the convolution, which uses DoubleConv. Two upsampling methods are available: Upsample (bilinear interpolation) and ConvTranspose2d (deconvolution).

Bilinear interpolation is well understood, schematic diagram:



Anyone familiar with bilinear interpolation will recognize this picture. Simply put: knowing the values at Q11, Q12, Q21, and Q22, we interpolate R1 from Q11 and Q21, R2 from Q12 and Q22, and finally P from R1 and R2. This process is bilinear interpolation.

For a feature map, this amounts to inserting new points between existing pixels, with the value of each new point determined by its neighboring pixels.
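A tiny sketch of this in PyTorch (not from the original article; the 2 * 2 input values are made up): with align_corners=True, nn.Upsample preserves the four corner values exactly and linearly interpolates everything in between.

```python
import torch
import torch.nn as nn

up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

x = torch.tensor([[[[1.0, 2.0],
                    [3.0, 4.0]]]])   # shape (1, 1, 2, 2)
y = up(x)                            # shape (1, 1, 4, 4)

# The corners of y equal the corners of x; intermediate pixels
# are linear blends of their neighbors.
print(y[0, 0])
```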

Deconvolution, also called transposed convolution, is, as the name implies, the reverse of convolution: where convolution makes the feature map smaller and smaller, deconvolution makes it larger and larger.



In the figure, the blue square below is the original feature map, the dashed white boxes around it are the padding (usually zeros), and the green square above is the convolved output.

The diagram shows the process of going from a 2 * 2 feature map to a 4 * 4 feature map.
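The 2 * 2 -> 4 * 4 doubling can be verified directly (a minimal sketch, not from the original article; the channel counts match the first Up step of this network). For ConvTranspose2d, O = (I - 1) * stride - 2 * padding + kernel, so with kernel 2 and stride 2: (2 - 1) * 2 + 2 = 4.

```python
import torch
import torch.nn as nn

# Transposed convolution with kernel 2, stride 2 exactly doubles H and W
# and, as configured in the Up module, halves the channel count:
deconv = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)

x = torch.randn(1, 1024, 2, 2)
y = deconv(x)
print(y.shape)  # torch.Size([1, 512, 4, 4])
```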

In the forward propagation function, x1 receives the upsampled data and x2 receives the data for feature fusion. The fusion method is the one described above: pad the smaller feature map first, then concat.

OutConv module:

With the DoubleConv, Down, and Up modules above, the main body of the UNet can be assembled. The UNet output then needs to map the channels to the number of segmentation classes, as shown below:



The operation is very simple: just a channel transformation. The figure above shows the two-class case.

Although this operation is very simple, it still has to be called, so for cleanliness it is encapsulated as a module too.

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
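A quick sketch of what OutConv does to the shapes (not from the original article; the 388 * 388 size and the two-class output are made up for illustration): a 1 * 1 convolution only remaps channels, leaving H and W untouched.

```python
import torch
import torch.nn as nn

out_conv = nn.Conv2d(64, 2, kernel_size=1)   # e.g. 2 segmentation classes

x = torch.randn(1, 64, 388, 388)
y = out_conv(x)
print(y.shape)  # torch.Size([1, 2, 388, 388])
```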

So far, all the modules used by the UNet network have been written. We can put all the module code above into a unet_parts.py file, then create unet_model.py and, following the UNet structure, set each module's input/output channel counts and calling order:

""" Full assembly of the parts to form the complete network """
"""Refer https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py"""

import torch.nn.functional as F

from unet_parts import *


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        self.up1 = Up(1024, 512, bilinear)
        self.up2 = Up(512, 256, bilinear)
        self.up3 = Up(256, 128, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
    
if __name__ == '__main__':
    net = UNet(n_channels=3, n_classes=1)
    print(net)

Run the command python unet_model.py; if there are no errors, you will get the following output:

UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
          (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
        )
      )
    )
  )
  (down2): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
          (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
        )
      )
    )
  )
  (down3): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
          (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
        )
      )
    )
  )
  (down4): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1))
          (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
          (4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
        )
      )
    )
  )
  (up1): Up(
    (up): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1))
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
        (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (up2): Up(
    (up): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1))
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (up3): Up(
    (up): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (up4): Up(
    (up): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (outc): OutConv(
    (conv): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))
  )
)

With the network built, the next step is to train it; the concrete implementation will be explained in the next article in this series.

3, Summary

  • This article mainly introduced the network structure of UNet and analyzed it module by module.
  • The next article explains how to write the training code for the UNet network.
If you like articles like this, search for the WeChat official account [Jack Cui AI] and follow a programmer making his way on the Internet.


Posted on Sun, 24 May 2020 07:08:47 -0400 by jpbellavance