[image deblurring] Rethinking Coarse-to-Fine Approach in Single Image Deblurring

Paper: https://arxiv.org/pdf/2108.05054.pdf

Code: https://github.com/chosj95/MIMO-UNet


The coarse-to-fine strategy has been widely used in the architecture design of single image deblurring networks. Conventional methods typically stack sub-networks with multi-scale input images and gradually improve the sharpness of the image from the bottom sub-network to the top one, which inevitably incurs a high computational cost. To design a fast and accurate deblurring network, the paper revisits the coarse-to-fine strategy and proposes a multi-input multi-output U-Net (MIMO-UNet).

MIMO-UNet has three distinct features.

First, the single encoder of MIMO-UNet takes multi-scale input images to ease the difficulty of training.

Second, the single decoder of MIMO-UNet outputs multiple deblurred images at different scales, so that a single U-shaped network mimics multi-cascaded U-Nets.

Finally, asymmetric feature fusion is introduced to merge multi-scale features efficiently.

Fig. 2: Comparison of coarse-to-fine deblurring networks.

In this paper, the authors revisit the coarse-to-fine scheme and propose a new deblurring network, called multi-input multi-output U-Net (MIMO-UNet), which can handle multi-scale blur with low computational complexity. The proposed MIMO-UNet is a single U-shaped encoder-decoder network with three distinct features (essentially the same as listed above, now with the module names).

First, the single decoder of MIMO-UNet outputs multiple deblurred images, so it is named the multi-output single decoder (MOSD). Although simple, MOSD mimics the conventional architecture composed of stacked sub-networks and guides the decoder layers to gradually recover the latent sharp image in a coarse-to-fine manner.

Second, the single encoder of MIMO-UNet takes multi-scale input images, so it is called the multi-input single encoder (MISE).

Finally, asymmetric feature fusion (AFF) is introduced to merge multi-scale features efficiently. AFF takes features from different scales and merges the multi-scale information flow across the encoder and decoder to improve deblurring performance.

Proposed method  

The proposed method is shown in Fig. 3. The encoder and decoder of MIMO-UNet are composed of three encoder blocks (EBs) and three decoder blocks (DBs), respectively.

Multi-input single encoder

It has been shown that different levels of blur in an image are handled better with multi-scale images. In MIMO-UNet, each EB, rather than a separate sub-network, takes a blurred image at a different scale as input. In other words, in addition to the downscaled features from the EB above, features are extracted from the downsampled blurred image (B2 and B3 in Fig. 3), and the two sets of features are then combined.

By using the complementary information in the downscaled features and the features obtained from the downsampled images, each EB is expected to handle diverse image blur effectively. Using multi-scale images as input to a single U-Net has also proved effective in other tasks, such as depth map super-resolution and object detection.
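
To make the scales concrete, the following minimal sketch lists the sizes of the three blurred inputs, assuming each level halves the previous one (as the forward pass later in this post does with scale_factor=0.5). The function name input_pyramid and the 256 x 256 example are illustrative and not taken from the paper.

```python
def input_pyramid(height, width, levels=3):
    """Return (name, height, width) for B_1..B_levels, halving at each level."""
    sizes = []
    for k in range(levels):
        scale = 2 ** k  # 1, 2, 4 for the three levels
        sizes.append((f"B_{k + 1}", height // scale, width // scale))
    return sizes


for name, h, w in input_pyramid(256, 256):
    print(name, h, w)
```

Choosing input sizes divisible by 4 keeps the three scales aligned, so the skip connections and the per-scale outputs line up without rounding.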

Fig. 4: Structures of the modules used in the network.

  First, a shallow convolutional module (SCM) is used to extract features from the downsampled image, as shown in Fig. 4(a). For efficiency, two stacks of 3 × 3 and 1 × 1 convolutional layers are used. The features from the last 1 × 1 layer are concatenated with the input Bk, and an additional 1 × 1 convolutional layer further refines the concatenated features.

Code (BasicConv and ResBlock are defined in layers.py, shown later):

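import torch
import torch.nn as nn

# Figure 4 (a) SCM module
class SCM(nn.Module):
    def __init__(self, out_plane):
        super(SCM, self).__init__()
        # Two stacks of 3x3 and 1x1 convolutions. The last layer emits
        # out_plane - 3 channels so that concatenating the 3-channel input
        # image gives exactly out_plane channels.
        self.main = nn.Sequential(
            BasicConv(3, out_plane//4, kernel_size=3, stride=1, relu=True),
            BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True),
            BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True),
            BasicConv(out_plane // 2, out_plane-3, kernel_size=1, stride=1, relu=True)
        )

        # Extra 1x1 convolution that refines the concatenated features
        self.conv = BasicConv(out_plane, out_plane, kernel_size=1, stride=1, relu=False)

    def forward(self, x):
        # Concatenate the input image B_k with the extracted features
        x = torch.cat([x, self.main(x)], dim=1)
        return self.conv(x)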
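
The reason the last layer of self.main outputs out_plane-3 channels becomes clear from a quick channel count. The sketch below only does the bookkeeping in plain Python; the helper name scm_channels is hypothetical, and the 3 input channels are assumed to be an RGB image.

```python
def scm_channels(out_plane, image_channels=3):
    """Channel counts after each layer of SCM.main, and after the concatenation."""
    main = [out_plane // 4, out_plane // 2, out_plane // 2, out_plane - image_channels]
    # torch.cat([x, self.main(x)], dim=1) adds the image channels back
    concat = image_channels + main[-1]
    return main, concat


main, concat = scm_channels(64)
print(main, concat)
```

After concatenation the channel count equals out_plane again, which is exactly what the final 1 × 1 convolution expects as input.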

A feature attention module (FAM) is used to actively emphasize or suppress features from the previous scale and to learn the spatial/channel importance of the features from SCM, as shown in Fig. 4(b).


Code (BasicConv and ResBlock are defined in layers.py, shown later):

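# Figure 4 (b) feature attention
class FAM(nn.Module):
    def __init__(self, channel):
        super(FAM, self).__init__()
        self.merge = BasicConv(channel, channel, kernel_size=3, stride=1, relu=False)

    def forward(self, x1, x2):
        # x1: features from the previous scale, x2: SCM features (same shape)
        x = x1 * x2                  # element-wise product acts as attention
        out = x1 + self.merge(x)     # residual connection around the 3x3 conv
        return out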
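
The effect of the element-wise product in FAM can be seen on plain numbers. The sketch below is a deliberate simplification: it replaces the learned 3 × 3 convolution merge with a hypothetical scalar weight (merge_weight), so it illustrates only the gating and the residual sum, not the real module.

```python
def fam_scalar(x1, x2, merge_weight=1.0):
    """Toy FAM on flat lists: out = x1 + merge_weight * (x1 * x2), element-wise."""
    return [a + merge_weight * (a * b) for a, b in zip(x1, x2)]


features = [1.0, 2.0, 3.0]    # stands in for the features from the previous scale
attention = [0.0, 1.0, -1.0]  # stands in for the SCM features
print(fam_scalar(features, attention))
```

Where the second input is zero the feature passes through unchanged, a positive value emphasizes it, and a negative value suppresses it.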

Multi-output single decoder

In MIMO-UNet, different DBs produce feature maps of different sizes. The authors argue that these multi-scale feature maps can be used to mimic multi-stacked sub-networks. Unlike conventional coarse-to-fine networks, where intermediate supervision is applied to each sub-network, here intermediate supervision is applied to each DB.

Specifically:

  Since the output of a DB is a feature map rather than an image, a mapping function o is needed to generate the intermediate output image; a single convolutional layer is used for it. The formula is indicated by the red arrow in the figure below.
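
Read off the forward pass shown later in this post, the per-scale outputs can be written roughly as follows. This is a reconstruction from the code, not a quotation of the paper: the superscript "out" and the arrow notation are assumptions, the up arrow stands for the upsampled features (a transposed convolution in the code), and the 1 × 1 convolution applied after concatenation is omitted for brevity.

```latex
\begin{aligned}
\hat{S}_3 &= o\big(\mathrm{DB}_3(\mathrm{EB}_3^{out})\big) + B_3, \\
\hat{S}_n &= o\big(\mathrm{DB}_n\big(\big[\mathrm{AFF}_n^{out};\, (\mathrm{DB}_{n+1}^{out})^{\uparrow}\big]\big)\big) + B_n, \qquad n = 1, 2.
\end{aligned}
```

In both cases the network predicts a residual that is added to the blurred input B_n of the same scale.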


Asymmetric feature fusion

In most conventional coarse-to-fine deblurring networks, only features from coarser-scale sub-networks are passed to finer-scale sub-networks, which makes the information flow inflexible. One approach is to cascade whole networks horizontally or vertically to allow both top-down and bottom-up information flow. Inspired by dense connections between intra-scale features, the authors propose an asymmetric feature fusion (AFF) module, shown in Fig. 4(c), which allows information from different scales to flow within a single U-Net. Each AFF takes the outputs of all EBs as input and combines the multi-scale features with convolutional layers.

The expression is given in formula 6:
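
Based on the calls to AFFs[0] and AFFs[1] in the forward pass shown later in this post, the fusion can be written as below. The notation is an assumption: the up and down arrows denote up-sampling and down-sampling to the target scale (F.interpolate in the code).

```latex
\begin{aligned}
\mathrm{AFF}_1^{out} &= \mathrm{AFF}_1\big(\mathrm{EB}_1^{out},\ (\mathrm{EB}_2^{out})^{\uparrow},\ (\mathrm{EB}_3^{out})^{\uparrow}\big), \\
\mathrm{AFF}_2^{out} &= \mathrm{AFF}_2\big((\mathrm{EB}_1^{out})^{\downarrow},\ \mathrm{EB}_2^{out},\ (\mathrm{EB}_3^{out})^{\uparrow}\big).
\end{aligned}
```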


Code (BasicConv and ResBlock are defined in layers.py, shown later):

# Figure 4 (c) in the paper: AFF module
class AFF(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(AFF, self).__init__()
        # 1x1 conv reduces the concatenated channels, 3x3 conv refines them
        self.conv = nn.Sequential(
            BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True),
            BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
        )

    def forward(self, x1, x2, x4):
        # x1, x2, x4: EB outputs already resized to the same spatial size
        x = torch.cat([x1, x2, x4], dim=1)
        return self.conv(x)
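
The network code later in this post builds both AFF modules with base_channel * 7 input channels. A short count, sketched here with a hypothetical helper aff_in_channels and the default base_channel of 32, shows where the factor 7 comes from: the three EB outputs are concatenated along the channel axis.

```python
def aff_in_channels(base_channel=32):
    """Channels of the three EB outputs and of their concatenation."""
    eb_channels = [base_channel, base_channel * 2, base_channel * 4]
    return eb_channels, sum(eb_channels)


eb_channels, total = aff_in_channels(32)
print(eb_channels, total, total == 32 * 7)
```

Resizing with F.interpolate changes only the spatial size, so the channel counts of the three inputs stay at 1, 2 and 4 times base_channel regardless of the target scale.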

Loss function: L1 loss (the content loss), as set up in the train.py training code:

criterion = torch.nn.L1Loss()

  Recent studies have shown that auxiliary loss terms, used in addition to the content loss, improve performance. In image enhancement and restoration tasks, auxiliary loss terms that minimize the distance between input and output in feature space have been widely used and have shown promising results.

  Since the purpose of deblurring is to recover the lost high-frequency components, it is very important to reduce the difference in frequency space. Therefore, a multi-scale frequency reconstruction (MSFR) loss function is proposed.
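
As far as I recall the paper, the loss terms have the form below, where K is the number of scales, t_k is the number of elements at scale k (used for normalization), the calligraphic F denotes the FFT, and lambda is set to 0.1. Treat the exact notation as an assumption; the last two lines are presumably the formulas 8 and 9 referred to below.

```latex
\begin{aligned}
L_{cont} &= \sum_{k=1}^{K} \frac{1}{t_k} \left\lVert \hat{S}_k - S_k \right\rVert_1, \\
L_{MSFR} &= \sum_{k=1}^{K} \frac{1}{t_k} \left\lVert \mathcal{F}(\hat{S}_k) - \mathcal{F}(S_k) \right\rVert_1, \\
L_{total} &= L_{cont} + \lambda L_{MSFR}.
\end{aligned}
```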


The code corresponding to formulas 8 and 9 is:
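
As a rough illustration of what these two formulas compute, here is a minimal pure-Python sketch; it is not the repository's train.py code. It assumes single-channel images given as 2D lists, uses a naive DFT in place of an FFT routine, counts real and imaginary parts as separate elements when averaging, and takes lam=0.1 as the weight. All function names are hypothetical.

```python
import cmath


def dft2(img):
    """Naive 2D discrete Fourier transform of a 2D list of numbers."""
    h, w = len(img), len(img[0])
    out = []
    for u in range(h):
        row = []
        for v in range(w):
            s = 0j
            for y in range(h):
                for x in range(w):
                    s += img[y][x] * cmath.exp(-2j * cmath.pi * (u * y / h + v * x / w))
            row.append(s)
        out.append(row)
    return out


def l1_mean(a, b):
    """Mean absolute difference between two 2D lists (content loss at one scale)."""
    diffs = [abs(p - q) for ra, rb in zip(a, b) for p, q in zip(ra, rb)]
    return sum(diffs) / len(diffs)


def freq_l1_mean(a, b):
    """Mean L1 distance in frequency space over real and imaginary parts."""
    fa, fb = dft2(a), dft2(b)
    diffs = []
    for ra, rb in zip(fa, fb):
        for p, q in zip(ra, rb):
            diffs.append(abs(p.real - q.real))
            diffs.append(abs(p.imag - q.imag))
    return sum(diffs) / len(diffs)


def total_loss(preds, targets, lam=0.1):
    """Content loss plus lam times the frequency loss, summed over all scales."""
    content = sum(l1_mean(p, t) for p, t in zip(preds, targets))
    msfr = sum(freq_l1_mean(p, t) for p, t in zip(preds, targets))
    return content + lam * msfr


pred = [[1.0, 1.0], [1.0, 1.0]]
sharp = [[0.0, 0.0], [0.0, 0.0]]
print(total_loss([pred], [sharp]))
```

In a PyTorch training loop the same idea would be applied to the three network outputs and to correspondingly downsampled ground-truth images, with a library FFT instead of the naive transform.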

MIMO-UNet network code (the comments correspond to the labels in the paper's figure):

import torch
import torch.nn as nn
import torch.nn.functional as F

# EBlock and DBlock are not shown in this post; see the repository.
class MIMOUNet(nn.Module):
    def __init__(self, num_res=8):
        super(MIMOUNet, self).__init__()

        base_channel = 32

        self.Encoder = nn.ModuleList([
            EBlock(base_channel, num_res),
            EBlock(base_channel*2, num_res),
            EBlock(base_channel*4, num_res),
        ])

        self.feat_extract = nn.ModuleList([
            BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1),
            BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2),
            BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2),
            BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1)
        ])

        self.Decoder = nn.ModuleList([
            DBlock(base_channel * 4, num_res),
            DBlock(base_channel * 2, num_res),
            DBlock(base_channel, num_res)
        ])

        self.Convs = nn.ModuleList([
            BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
            BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
        ])

        self.ConvsOut = nn.ModuleList([
            BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
            BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
        ])

        self.AFFs = nn.ModuleList([
            AFF(base_channel * 7, base_channel*1),
            AFF(base_channel * 7, base_channel*2)
        ])

        self.FAM1 = FAM(base_channel * 4)
        self.SCM1 = SCM(base_channel * 4)
        self.FAM2 = FAM(base_channel * 2)
        self.SCM2 = SCM(base_channel * 2)

    def forward(self, x):
        x_2 = F.interpolate(x, scale_factor=0.5)    # Downsample B1 to get B2
        x_4 = F.interpolate(x_2, scale_factor=0.5)  # Downsample B2 to get B3
        z2 = self.SCM2(x_2)  # B2 through SCM_2
        z4 = self.SCM1(x_4)  # B3 through SCM_3

        outputs = list()

        x_ = self.feat_extract[0](x)  # Conv 3x3
        res1 = self.Encoder[0](x_)    # EB1

        z = self.feat_extract[1](res1)  # Conv 3x3, stride 2
        z = self.FAM2(z, z2)            # Fuse with SCM_2 features before EB2
        res2 = self.Encoder[1](z)       # EB2

        z = self.feat_extract[2](res2)  # Conv 3x3, stride 2
        z = self.FAM1(z, z4)            # Fuse with SCM_3 features before EB3
        z = self.Encoder[2](z)          # EB3

        z12 = F.interpolate(res1, scale_factor=0.5) # EB1 output downsampled for AFF2
        z21 = F.interpolate(res2, scale_factor=2)   # EB2 output upsampled for AFF1
        z42 = F.interpolate(z, scale_factor=2)      # EB3 output upsampled for AFF2
        z41 = F.interpolate(z42, scale_factor=2)    # EB3 output upsampled for AFF1

        res2 = self.AFFs[1](z12, res2, z42) # AFF_2 fusion
        res1 = self.AFFs[0](res1, z21, z41) # AFF_1 fusion

        z = self.Decoder[0](z)       # DB3
        z_ = self.ConvsOut[0](z)     # Conv 3x3 maps the features to an h/4 x w/4 x 3 image
        z = self.feat_extract[3](z)  # 4x4 transposed convolution (upsample x2)
        outputs.append(z_+x_4)       # B3 + h/4 x w/4 x 3 ==> S^_3 (element-wise sum)

        z = torch.cat([z, res2], dim=1)
        z = self.Convs[0](z)         # Conv 1x1
        z = self.Decoder[1](z)       # DB2
        z_ = self.ConvsOut[1](z)     # Conv 3x3 maps the features to an h/2 x w/2 x 3 image
        z = self.feat_extract[4](z)  # 4x4 transposed convolution (upsample x2)
        outputs.append(z_+x_2)       # B2 + h/2 x w/2 x 3 ==> S^_2 (element-wise sum)

        z = torch.cat([z, res1], dim=1)
        z = self.Convs[1](z)         # Conv 1x1
        z = self.Decoder[2](z)       # DB1
        z = self.feat_extract[5](z)  # Conv 3x3 maps the features to an h x w x 3 image
        outputs.append(z+x)          # B1 + h x w x 3 ==> S^_1 (element-wise sum)

        return outputs  # [S^_3, S^_2, S^_1], coarsest first

Experimental results:


# layers.py
import torch
import torch.nn as nn

# Conv2d -> BN -> ReLU  or ConvTranspose2d -> BN -> ReLU
class BasicConv(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False):
        super(BasicConv, self).__init__()
        # BatchNorm has its own shift, so the conv bias is redundant with it
        if bias and norm:
            bias = False

        padding = kernel_size // 2
        layers = list()
        if transpose:
            padding = kernel_size // 2 -1
            layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
        else:
            layers.append(
                nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
        if norm:
            layers.append(nn.BatchNorm2d(out_channel))
        if relu:
            layers.append(nn.ReLU(inplace=True))
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)

# Residual block
class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(ResBlock, self).__init__()
        self.main = nn.Sequential(
            BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True),
            BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
        )

    def forward(self, x):
        # Identity skip connection; requires in_channel == out_channel
        return self.main(x) + x
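
The padding choices in BasicConv decide how each layer changes the spatial size. The sketch below applies the standard output-size formulas for Conv2d and ConvTranspose2d (with dilation 1 and no output padding) to the padding rules used above; the helper names conv_out and conv_transpose_out are hypothetical.

```python
def conv_out(size, kernel_size, stride):
    """Output length along one axis for Conv2d with padding = kernel_size // 2."""
    padding = kernel_size // 2
    return (size + 2 * padding - kernel_size) // stride + 1


def conv_transpose_out(size, kernel_size, stride):
    """Output length along one axis for ConvTranspose2d with padding = kernel_size // 2 - 1."""
    padding = kernel_size // 2 - 1
    return (size - 1) * stride - 2 * padding + kernel_size


print(conv_out(256, 3, 1))            # 3x3 conv, stride 1
print(conv_out(256, 3, 2))            # 3x3 conv, stride 2
print(conv_transpose_out(64, 4, 2))   # 4x4 transposed conv, stride 2
```

With these settings a stride-1 convolution keeps the size, a stride-2 convolution halves an even size, and the 4 × 4 transposed convolution doubles it, which is what lets encoder and decoder features of the same scale be concatenated.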

Tags: neural networks Deep Learning

Posted on Mon, 04 Oct 2021 15:53:09 -0400 by consolestrat