Paper: https://arxiv.org/pdf/2108.05054.pdf
Code: https://github.com/chosj95/MIMO-UNet
The coarse-to-fine strategy has been widely used in the design of single-image deblurring networks. Conventional methods usually stack sub-networks that take multi-scale input images and gradually improve image sharpness from the bottom sub-network to the top one, which inevitably incurs a high computational cost. To build a fast and accurate deblurring network, the authors revisit the coarse-to-fine strategy and propose a multi-input multi-output U-Net (MIMO-UNet).
MIMO-UNet has three distinctive features.
First, the single encoder of MIMO-UNet takes multi-scale input images, which eases training.
Second, the single decoder of MIMO-UNet outputs multiple deblurred images at different scales, so that a single U-shaped network emulates multiple cascaded U-shaped networks.
Finally, asymmetric feature fusion is introduced to merge multi-scale features efficiently.
Fig. 2: Comparison of coarse-to-fine deblurring networks.
In this paper, the authors revisit the coarse-to-fine scheme and propose a new deblurring network, the multi-input multi-output U-Net (MIMO-UNet), which handles multi-scale blur with low computational complexity. MIMO-UNet is a single U-shaped encoder-decoder network with three distinctive features (essentially the same as described above).
First, the single decoder of MIMO-UNet outputs multiple deblurred images, so it is called the multi-output single decoder (MOSD). Although simple, MOSD can emulate the conventional architecture of stacked sub-networks and guides the decoder layers to restore the latent sharp image progressively, in a coarse-to-fine manner.
Second, the single encoder of MIMO-UNet takes multi-scale input images, so it is called the multi-input single encoder (MISE).
Finally, asymmetric feature fusion (AFF) is introduced to merge multi-scale features efficiently. AFF takes features of different scales and combines the multi-scale information flow across the encoder and decoder to improve deblurring performance.
Proposed method
The overall architecture is shown in Fig. 3. The encoder and decoder of MIMO-UNet consist of three encoder blocks (EBs) and three decoder blocks (DBs), respectively.
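As a quick sketch of the resolutions and widths at each of the three levels, assuming base_channel = 32 and the stride-2 downsampling used in the MIMOUNet code further down (the helper name `level_shapes` is hypothetical, and the halving is exact only when H and W are divisible by 4):

```python
def level_shapes(height, width, base_channel=32, levels=3):
    """Return (channels, height, width) of the EB/DB features at each level.

    Level 1 works at full resolution; every following level halves the
    spatial size (stride-2 conv in the encoder) and doubles the channels.
    """
    shapes = []
    for k in range(levels):
        scale = 2 ** k
        shapes.append((base_channel * scale, height // scale, width // scale))
    return shapes


for k, (c, h, w) in enumerate(level_shapes(256, 256), start=1):
    print(f"EB{k}/DB{k}: {c} x {h} x {w}")
```

For a 256 × 256 input this gives 32 × 256 × 256, 64 × 128 × 128 and 128 × 64 × 64, which matches the channel counts passed to EBlock/DBlock in the network code.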
Multi-input single encoder
Previous work has shown that images with different levels of blur are better handled using multi-scale images. In MIMO-UNet, instead of separate sub-networks, the EBs take blurred images of different scales as input. In other words, besides the downscaled features coming from the preceding EB, features are also extracted from the downsampled blurred image (e.g., B2 and B3 in the figure), and the two sets of features are then combined.
By exploiting the complementary information in the downscaled features and in the features extracted from the downsampled image, each EB is expected to handle different kinds of blur effectively. Using multi-scale images as input to a single U-Net has also proved effective in other tasks, such as depth map super-resolution and object detection.
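The multi-scale inputs B1, B2 and B3 can be built the same way the MIMOUNet code below does it, with two successive F.interpolate calls at a factor of 0.5. A minimal sketch (the function name `blur_pyramid` is hypothetical; it mirrors the code's default nearest-neighbor resampling, which is not necessarily what the paper describes):

```python
import torch
import torch.nn.functional as F


def blur_pyramid(b1):
    """Build the multi-scale inputs B1, B2, B3 from a blurred batch (N, 3, H, W).

    Each level is the previous one resized by 0.5, as at the start of
    MIMOUNet.forward (F.interpolate default mode='nearest').
    """
    b2 = F.interpolate(b1, scale_factor=0.5)
    b3 = F.interpolate(b2, scale_factor=0.5)
    return b1, b2, b3


blurred = torch.rand(1, 3, 256, 256)
for name, img in zip(("B1", "B2", "B3"), blur_pyramid(blurred)):
    print(name, tuple(img.shape))
```

B2 and B3 then go through SCM before being merged into EB2 and EB3, respectively.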
Fig. 4: Structure of the modules used in the network.
First, a shallow convolutional module (SCM) extracts features from the downsampled image, as shown in Fig. 4(a). For efficiency, it uses two stacks of 3 × 3 and 1 × 1 convolutional layers. The features from the last 1 × 1 layer are concatenated with the input B_k, and an additional 1 × 1 convolutional layer further refines the concatenated features.
Code (BasicConv and ResBlock are defined in layers.py at the end of this post):
import torch
import torch.nn as nn
# BasicConv is defined in layers.py (end of this post)


# Fig. 4(a): shallow convolutional module (SCM)
class SCM(nn.Module):
    def __init__(self, out_plane):
        super(SCM, self).__init__()
        self.main = nn.Sequential(
            BasicConv(3, out_plane // 4, kernel_size=3, stride=1, relu=True),
            BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True),
            BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True),
            BasicConv(out_plane // 2, out_plane - 3, kernel_size=1, stride=1, relu=True)
        )
        # extra 1x1 conv that refines the concatenated features
        self.conv = BasicConv(out_plane, out_plane, kernel_size=1, stride=1, relu=False)

    def forward(self, x):
        # concatenate the input image B_k (3 channels) with its features (out_plane - 3 channels)
        x = torch.cat([x, self.main(x)], dim=1)
        return self.conv(x)
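The main branch deliberately ends at out_plane - 3 channels so that, after concatenation with the 3-channel input B_k, the final 1 × 1 conv receives exactly out_plane channels. A stripped-down, self-contained sketch to check this bookkeeping (plain nn.Conv2d + ReLU stand in for BasicConv; the class name `SCMSketch` is hypothetical):

```python
import torch
import torch.nn as nn


class SCMSketch(nn.Module):
    """Same layer layout as SCM above, with nn.Conv2d + ReLU in place of BasicConv."""

    def __init__(self, out_plane):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, out_plane // 4, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(out_plane // 4, out_plane // 2, 1), nn.ReLU(inplace=True),
            nn.Conv2d(out_plane // 2, out_plane // 2, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(out_plane // 2, out_plane - 3, 1), nn.ReLU(inplace=True),
        )
        self.conv = nn.Conv2d(out_plane, out_plane, 1)  # refines [B_k, features]

    def forward(self, x):
        feats = self.main(x)              # (N, out_plane - 3, H, W)
        x = torch.cat([x, feats], dim=1)  # (N, out_plane, H, W)
        return self.conv(x)


b2 = torch.rand(1, 3, 128, 128)           # half-resolution blurred input B2
print(tuple(SCMSketch(64)(b2).shape))     # (1, 64, 128, 128)
```

With out_plane = 64 the output width matches EB2, which is why the network builds SCM2 as SCM(base_channel * 2).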
A feature attention module (FAM) is then used to actively emphasize or suppress the features from the previous scale and to learn the spatial/channel importance of the SCM features, as shown in Fig. 4(b).
Code (BasicConv and ResBlock are defined in layers.py at the end of this post):
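import torch.nn as nn
# BasicConv is defined in layers.py (end of this post)


# Fig. 4(b): feature attention module (FAM)
class FAM(nn.Module):
    def __init__(self, channel):
        super(FAM, self).__init__()
        self.merge = BasicConv(channel, channel, kernel_size=3, stride=1, relu=False)

    def forward(self, x1, x2):
        # x1: downsampled features from the previous EB; x2: SCM features at the same scale
        x = x1 * x2                 # element-wise gating
        out = x1 + self.merge(x)    # residual connection back to x1
        return out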
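Read as a formula, FAM computes x1 + merge(x1 ⊙ x2): the SCM features x2 gate the encoder features x1 element-wise, and the gated product is added back as a residual. One consequence, sketched below with nn.Conv2d standing in for BasicConv (the class name `FAMSketch` is hypothetical): if the merge convolution outputs zeros, FAM reduces to the identity on x1, so the module can only add to what the encoder already carries.

```python
import torch
import torch.nn as nn


class FAMSketch(nn.Module):
    def __init__(self, channel):
        super().__init__()
        self.merge = nn.Conv2d(channel, channel, 3, padding=1)  # 3x3 conv, no ReLU

    def forward(self, x1, x2):
        return x1 + self.merge(x1 * x2)  # residual of the gated product


fam = FAMSketch(64)
nn.init.zeros_(fam.merge.weight)
nn.init.zeros_(fam.merge.bias)

x1 = torch.rand(1, 64, 32, 32)  # downsampled encoder features
x2 = torch.rand(1, 64, 32, 32)  # SCM features at the same scale
print(torch.equal(fam(x1, x2), x1))  # True: a zero merge leaves x1 unchanged
```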
Multi-output single decoder
In MIMO-UNet, different DBs produce feature maps of different sizes. The authors argue that these multi-scale feature maps can be used to emulate multiple stacked sub-networks. Unlike conventional coarse-to-fine networks, which apply intermediate supervision to each sub-network, MIMO-UNet applies intermediate supervision to each DB.
Concretely:
Since the output of a DB is a feature map rather than an image, a mapping function is needed to generate the intermediate output image; a single convolutional layer serves this purpose. The formula corresponds to the red arrows in the figure below.
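In the MIMOUNet code below, this mapping is a single 3 × 3 conv (ConvsOut) whose result is added to the blurred input at the same scale (e.g. `z_ + x_4`), so the intermediate output reads as Ŝ_k = conv3x3(DB_k output) + B_k. A self-contained sketch of that output head (the class name `OutputHead` is hypothetical):

```python
import torch
import torch.nn as nn


class OutputHead(nn.Module):
    """Maps DB features to a 3-channel image, predicted as a residual on B_k."""

    def __init__(self, channels):
        super().__init__()
        self.to_img = nn.Conv2d(channels, 3, kernel_size=3, padding=1)  # single conv layer

    def forward(self, db_feat, blurred_k):
        return self.to_img(db_feat) + blurred_k  # S^_k = conv(DB_k output) + B_k


db3 = torch.rand(1, 128, 64, 64)  # DB3 output at 1/4 resolution
b3 = torch.rand(1, 3, 64, 64)     # blurred input B3 at 1/4 resolution
s3 = OutputHead(128)(db3, b3)
print(tuple(s3.shape))            # (1, 3, 64, 64)
```

Predicting a residual on B_k means a head that outputs zeros simply returns the blurred input, so each scale only has to learn the correction.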
Asymmetric feature fusion
In most conventional coarse-to-fine deblurring networks, only the features from the coarser-scale sub-network are used by the finer-scale sub-network, which makes the information flow inflexible. An exception is to cascade the whole network horizontally or vertically, allowing both top-down and bottom-up information flow. Inspired by dense connections between intra-scale features, the authors propose an asymmetric feature fusion (AFF) module, shown in Fig. 4(c), that lets information from different scales flow within a single U-Net. Each AFF takes the outputs of all EBs as input and combines the multi-scale features with convolutional layers.
This is expressed in Eq. (6):
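Reading the AFF code and the MIMOUNet forward pass below, the fusion can be written as follows. This is a transcription of the code, not necessarily the paper's exact notation: ↑ and ↓ denote 2× up- and downsampling, [·,·,·] is channel-wise concatenation, AFF_1 feeds DB1 and AFF_2 feeds DB2.

```latex
\begin{aligned}
\mathrm{AFF}^{out}_{1} &= \mathrm{AFF}_{1}\!\left(\mathrm{EB}^{out}_{1},\ (\mathrm{EB}^{out}_{2})^{\uparrow},\ (\mathrm{EB}^{out}_{3})^{\uparrow\uparrow}\right)\\
\mathrm{AFF}^{out}_{2} &= \mathrm{AFF}_{2}\!\left((\mathrm{EB}^{out}_{1})^{\downarrow},\ \mathrm{EB}^{out}_{2},\ (\mathrm{EB}^{out}_{3})^{\uparrow}\right)\\
\mathrm{AFF}_{n}(a, b, c) &= \mathrm{Conv}_{3\times 3}\!\left(\mathrm{ReLU}\!\left(\mathrm{Conv}_{1\times 1}([a, b, c])\right)\right)
\end{aligned}
```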
Code (BasicConv and ResBlock are defined in layers.py at the end of this post):
import torch
import torch.nn as nn
# BasicConv is defined in layers.py (end of this post)


# Fig. 4(c): asymmetric feature fusion (AFF) module
class AFF(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(AFF, self).__init__()
        self.conv = nn.Sequential(
            BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True),
            BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
        )

    def forward(self, x1, x2, x4):
        # x1, x2, x4: EB1/EB2/EB3 outputs, already resized to the target resolution
        x = torch.cat([x1, x2, x4], dim=1)
        return self.conv(x)
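Both AFFs in MIMOUNet are built with in_channel = base_channel * 7: with base_channel = 32 the EB outputs have 32, 64 and 128 channels, and 32 + 64 + 128 = 224 = 7 × 32. A self-contained shape check for AFF_1, using plain nn.Conv2d in place of BasicConv (the helper name `aff1_inputs` is hypothetical):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

base = 32
eb1 = torch.rand(1, base, 64, 64)      # EB1 output, full resolution
eb2 = torch.rand(1, base * 2, 32, 32)  # EB2 output, 1/2 resolution
eb3 = torch.rand(1, base * 4, 16, 16)  # EB3 output, 1/4 resolution


def aff1_inputs(eb1, eb2, eb3):
    """Bring all EB outputs to EB1's resolution, as done before AFF_1 in MIMOUNet.forward."""
    up2 = F.interpolate(eb2, scale_factor=2)
    up4 = F.interpolate(F.interpolate(eb3, scale_factor=2), scale_factor=2)
    return torch.cat([eb1, up2, up4], dim=1)


aff1 = nn.Sequential(
    nn.Conv2d(base * 7, base, kernel_size=1), nn.ReLU(inplace=True),  # 1x1 conv + ReLU
    nn.Conv2d(base, base, kernel_size=3, padding=1),                   # 3x3 conv, no ReLU
)

fused_in = aff1_inputs(eb1, eb2, eb3)
print(tuple(fused_in.shape))        # (1, 224, 64, 64)
print(tuple(aff1(fused_in).shape))  # (1, 32, 64, 64)
```

AFF_2 is the same idea at half resolution: EB1 is downsampled, EB2 is kept, EB3 is upsampled once, and the output has 64 channels.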
Loss function: L1 loss
criterion = torch.nn.L1Loss()
(The line above is from the training script, train.py.)
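Since the decoder emits three images, the L1 criterion is applied at every scale. A minimal sketch, assuming the outputs are ordered [Ŝ_3, Ŝ_2, Ŝ_1] as returned by MIMOUNet.forward and that the sharp targets are downsampled with the same F.interpolate calls as the inputs (the function name `multiscale_l1` is hypothetical):

```python
import torch
import torch.nn.functional as F

criterion = torch.nn.L1Loss()  # mean absolute error (averaged over all elements)


def multiscale_l1(outputs, sharp):
    """Sum of per-scale L1 losses.

    outputs: [S^_3, S^_2, S^_1] at 1/4, 1/2 and full resolution (MIMOUNet order).
    sharp:   full-resolution ground truth S_1, downsampled here to S_2 and S_3.
    """
    sharp_2 = F.interpolate(sharp, scale_factor=0.5)
    sharp_4 = F.interpolate(sharp_2, scale_factor=0.5)
    targets = [sharp_4, sharp_2, sharp]
    return sum(criterion(out, tgt) for out, tgt in zip(outputs, targets))


sharp = torch.rand(1, 3, 64, 64)
half = F.interpolate(sharp, scale_factor=0.5)
quarter = F.interpolate(half, scale_factor=0.5)
print(multiscale_l1([quarter, half, sharp], sharp).item())  # 0.0 for a perfect prediction
```

Because L1Loss averages over elements, each scale is already normalized by its own size, so the small 1/4-resolution output is not drowned out by the full-resolution one.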
Recent studies have also shown that, in addition to the content loss, auxiliary loss terms improve performance. In image enhancement and restoration tasks, auxiliary losses that minimize the distance between input and output in feature space have been widely used with good results.
Since the purpose of deblurring is to recover lost high-frequency components, it is important to reduce the difference in frequency space. Therefore, the authors propose a multi-scale frequency reconstruction (MSFR) loss.
The code corresponding to Eqs. (8) and (9) is:
MIMO-UNet network code (comments refer to the labels in the paper's figure):
import torch
import torch.nn as nn
import torch.nn.functional as F
# BasicConv/ResBlock: layers.py (end of this post); SCM, FAM, AFF: defined above.
# EBlock/DBlock are not listed in this post.


class MIMOUNet(nn.Module):
    def __init__(self, num_res=8):
        super(MIMOUNet, self).__init__()

        base_channel = 32

        self.Encoder = nn.ModuleList([
            EBlock(base_channel, num_res),
            EBlock(base_channel*2, num_res),
            EBlock(base_channel*4, num_res),
        ])

        self.feat_extract = nn.ModuleList([
            BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1),
            BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2),
            BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2),
            BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1)
        ])

        self.Decoder = nn.ModuleList([
            DBlock(base_channel * 4, num_res),
            DBlock(base_channel * 2, num_res),
            DBlock(base_channel, num_res)
        ])

        self.Convs = nn.ModuleList([
            BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
            BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
        ])

        self.ConvsOut = nn.ModuleList([
            BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
            BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
        ])

        self.AFFs = nn.ModuleList([
            AFF(base_channel * 7, base_channel*1),
            AFF(base_channel * 7, base_channel*2)
        ])

        self.FAM1 = FAM(base_channel * 4)
        self.SCM1 = SCM(base_channel * 4)
        self.FAM2 = FAM(base_channel * 2)
        self.SCM2 = SCM(base_channel * 2)

    def forward(self, x):
        x_2 = F.interpolate(x, scale_factor=0.5)     # downsample: B2
        x_4 = F.interpolate(x_2, scale_factor=0.5)   # downsample: B3
        z2 = self.SCM2(x_2)                          # B2 through SCM_2
        z4 = self.SCM1(x_4)                          # B3 through SCM_3

        outputs = list()

        x_ = self.feat_extract[0](x)                 # Conv 3x3
        res1 = self.Encoder[0](x_)                   # EB1

        z = self.feat_extract[1](res1)               # Conv 3x3, stride 2
        z = self.FAM2(z, z2)                         # fuse SCM_2 features before EB2
        res2 = self.Encoder[1](z)                    # EB2

        z = self.feat_extract[2](res2)               # Conv 3x3, stride 2
        z = self.FAM1(z, z4)                         # fuse SCM_3 features before EB3
        z = self.Encoder[2](z)                       # EB3

        z12 = F.interpolate(res1, scale_factor=0.5)  # EB1 output downsampled for AFF_2
        z21 = F.interpolate(res2, scale_factor=2)    # EB2 output upsampled for AFF_1
        z42 = F.interpolate(z, scale_factor=2)       # EB3 output upsampled for AFF_2
        z41 = F.interpolate(z42, scale_factor=2)     # EB3 output upsampled again for AFF_1

        res2 = self.AFFs[1](z12, res2, z42)          # AFF_2 fusion
        res1 = self.AFFs[0](res1, z21, z41)          # AFF_1 fusion

        z = self.Decoder[0](z)                       # DB3
        z_ = self.ConvsOut[0](z)                     # conv -> H/4 x W/4 x 3 map
        z = self.feat_extract[3](z)                  # 4x4 transposed conv (2x upsampling)
        outputs.append(z_ + x_4)                     # B3 + (H/4 x W/4 x 3) => S^_3 (element-wise sum)

        z = torch.cat([z, res2], dim=1)
        z = self.Convs[0](z)                         # Conv 1x1
        z = self.Decoder[1](z)                       # DB2
        z_ = self.ConvsOut[1](z)                     # conv -> H/2 x W/2 x 3 map
        z = self.feat_extract[4](z)                  # 4x4 transposed conv (2x upsampling)
        outputs.append(z_ + x_2)                     # B2 + (H/2 x W/2 x 3) => S^_2 (element-wise sum)

        z = torch.cat([z, res1], dim=1)
        z = self.Convs[1](z)                         # Conv 1x1
        z = self.Decoder[2](z)                       # DB1
        z = self.feat_extract[5](z)                  # conv 3x3 -> H x W x 3
        outputs.append(z + x)                        # B1 + (H x W x 3) => S^_1

        return outputs                               # [S^_3, S^_2, S^_1]
Experimental results:
layers.py
import torch
import torch.nn as nn


# Conv2d -> BN -> ReLU  or  ConvTranspose2d -> BN -> ReLU
class BasicConv(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False):
        super(BasicConv, self).__init__()
        if bias and norm:
            bias = False  # BatchNorm makes the conv bias redundant

        padding = kernel_size // 2
        layers = list()
        if transpose:
            # kernel 4, stride 2, padding 1 exactly doubles H and W
            padding = kernel_size // 2 - 1
            layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
        else:
            layers.append(
                nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
        if norm:
            layers.append(nn.BatchNorm2d(out_channel))
        if relu:
            layers.append(nn.ReLU(inplace=True))
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)


# Residual block: conv3x3 -> ReLU -> conv3x3, plus identity shortcut
class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(ResBlock, self).__init__()
        self.main = nn.Sequential(
            BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True),
            BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
        )

    def forward(self, x):
        return self.main(x) + x
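EBlock and DBlock, which MIMOUNet uses, are not listed in this post. A minimal sketch, assuming each is simply num_res ResBlocks applied in sequence at a fixed channel width (consistent with their (channels, num_res) constructor arguments, but an assumption here). To stay self-contained it repeats a conv-only ResBlock, simplified to a single channel argument:

```python
import torch
import torch.nn as nn


class ResBlock(nn.Module):
    """conv3x3 -> ReLU -> conv3x3 plus identity, as in layers.py (BasicConv inlined)."""

    def __init__(self, channels):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, padding=1),
        )

    def forward(self, x):
        return self.main(x) + x


class EBlock(nn.Module):
    """Assumed structure: num_res shape-preserving residual blocks."""

    def __init__(self, out_channel, num_res=8):
        super().__init__()
        self.layers = nn.Sequential(*[ResBlock(out_channel) for _ in range(num_res)])

    def forward(self, x):
        return self.layers(x)


class DBlock(EBlock):
    """Assumed to mirror EBlock: the same residual stack on the decoder side."""


x = torch.rand(1, 64, 32, 32)
print(tuple(EBlock(64)(x).shape), tuple(DBlock(64)(x).shape))  # (1, 64, 32, 32) (1, 64, 32, 32)
```

Because these blocks preserve shape, all resolution and width changes in MIMOUNet happen in feat_extract, Convs and the AFFs.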