1, Overall structure
HRNet can be divided into three parts, in order:
1. Stem net
The stem takes the input image and outputs a feature map at 1/4 of the input resolution (the size noted for stage 1 in section 3). From this point on, HRNet always keeps a branch at this resolution.
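As a rough check of the size arithmetic, the sketch below assumes the stem consists of two 3x3 convolutions with stride 2 and padding 1 (the usual HRNet stem layout; the text above does not spell it out) and applies the standard convolution output-size formula. The helper names are illustrative only.

```python
def conv_out_size(size, kernel=3, stride=2, padding=1):
    """Spatial output size of a convolution (standard floor formula)."""
    return (size + 2 * padding - kernel) // stride + 1


def stem_out_size(height, width, num_convs=2):
    """Apply `num_convs` stride-2 3x3 convolutions (assumed stem layout)."""
    for _ in range(num_convs):
        height = conv_out_size(height)
        width = conv_out_size(width)
    return height, width


print(stem_out_size(512, 1024))  # (128, 256)
```

Under these assumptions each side shrinks by a factor of 4, which matches the 1/4 feature map mentioned for stage 1.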
2. Four stages of HRNet
(1) The multi-scale feature maps generated by each stage are configured as shown in Table 1.
(2) A transition structure sits between consecutive stages. It adapts the number of channels and the feature map sizes to what the next stage expects.
Table 1
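Independently of the channel numbers in Table 1, the branch layout across stages can be sketched as follows. The sketch assumes that stage n runs n parallel branches and that each added branch halves the resolution of the previous one, starting from the 1/4-resolution stem output. `branch_strides` is a hypothetical helper, not part of the HRNet code.

```python
def branch_strides(stage, stem_stride=4):
    """Downsampling factor (relative to the input image) of each branch in a stage."""
    if not 1 <= stage <= 4:
        raise ValueError("HRNet as described here has stages 1 to 4")
    return [stem_stride * 2 ** i for i in range(stage)]


for stage in range(1, 5):
    print(stage, branch_strides(stage))
# 1 [4]
# 2 [4, 8]
# 3 [4, 8, 16]
# 4 [4, 8, 16, 32]
```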
3. Segmentation head
Concatenate the features of the four scales output by stage 4 and add a num_channels -> num_classes layer to obtain the segmentation result.
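The channel bookkeeping of the head can be sketched as follows. It assumes the lower-resolution outputs are first resized to the highest resolution so that they can be concatenated along the channel axis (the text above only mentions the concatenation). The widths 48, 96, 192, 384 and the 19 classes are example values only.

```python
def head_shapes(branch_shapes, num_classes):
    """branch_shapes: list of (channels, height, width), highest resolution first."""
    _, height, width = branch_shapes[0]
    # After resizing every branch to (height, width), concatenation sums the channels.
    concat_channels = sum(c for c, _, _ in branch_shapes)
    concat_shape = (concat_channels, height, width)
    # The final num_channels -> num_classes layer keeps the spatial size.
    logits_shape = (num_classes, height, width)
    return concat_shape, logits_shape


shapes = [(48, 128, 256), (96, 64, 128), (192, 32, 64), (384, 16, 32)]
print(head_shapes(shapes, num_classes=19))
# ((720, 128, 256), (19, 128, 256))
```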
2, Building blocks applied in HRNet
1. The ordinary 3x3 convolution has the following structure:
import torch.nn as nn


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)
2. The basic block has the following structure:
# BN_MOMENTUM is a module-level constant defined elsewhere in the source file.
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the shortcut when its shape differs from the main path.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
3. The three-layer residual block (Bottleneck)
The expansion parameter controls the ratio between the block's output channels and planes: the last 1x1 convolution outputs planes * expansion channels.
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion,
                                  momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
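To make the role of expansion concrete, the following sketch lists the channel counts of the three convolutions in a Bottleneck and checks when the shortcut needs a projection. It only does arithmetic on integers; the helper names are illustrative and the 64-channel input is an example value.

```python
def bottleneck_convs(inplanes, planes, expansion=4):
    """(kernel, in_channels, out_channels) of the three convolutions."""
    return [
        (1, inplanes, planes),            # 1x1: reduce to `planes`
        (3, planes, planes),              # 3x3: spatial convolution
        (1, planes, planes * expansion),  # 1x1: expand to planes * expansion
    ]


def needs_downsample(inplanes, planes, stride=1, expansion=4):
    """The shortcut needs a projection when its shape differs from the output."""
    return stride != 1 or inplanes != planes * expansion


print(bottleneck_convs(64, 64))
# [(1, 64, 64), (3, 64, 64), (1, 64, 256)]
print(needs_downsample(64, 64), needs_downsample(256, 64))
# True False
```

With expansion = 1, as in BasicBlock, the output width equals planes and no projection is needed when inplanes == planes and stride is 1.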
3, Specific module
1. HRNet core module class: HighResolutionModule
This class implements the parallel branches for multi-scale feature extraction and, at the end of the module, multi-scale feature fusion through upsampling/downsampling (the red box in the figure below).
class HighResolutionModule(nn.Module):
    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True):
        super(HighResolutionModule, self).__init__()
        self._check_branches(
            num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(False)
2. _check_branches() function
This function checks that, in HighResolutionModule, num_branches (an int) equals len(num_blocks), len(num_channels) and len(num_inchannels) (each a list of ints).
# logger is assumed to be a module-level logging.Logger defined elsewhere in the file.
def _check_branches(self, num_branches, blocks, num_blocks,
                    num_inchannels, num_channels):
    if num_branches != len(num_blocks):
        error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
            num_branches, len(num_blocks))
        logger.error(error_msg)
        raise ValueError(error_msg)

    if num_branches != len(num_channels):
        error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
            num_branches, len(num_channels))
        logger.error(error_msg)
        raise ValueError(error_msg)

    if num_branches != len(num_inchannels):
        error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
            num_branches, len(num_inchannels))
        logger.error(error_msg)
        raise ValueError(error_msg)
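The same validation can be tried in isolation with the standalone sketch below. It is written as a free function without logging, so `check_branches` here is an illustrative stand-in rather than the method above, and the list values are examples.

```python
def check_branches(num_branches, num_blocks, num_inchannels, num_channels):
    """Raise ValueError unless every per-branch list has num_branches entries."""
    for name, values in (("NUM_BLOCKS", num_blocks),
                         ("NUM_CHANNELS", num_channels),
                         ("NUM_INCHANNELS", num_inchannels)):
        if num_branches != len(values):
            raise ValueError("NUM_BRANCHES({}) <> {}({})".format(
                num_branches, name, len(values)))


check_branches(2, [4, 4], [48, 96], [48, 96])  # consistent: passes silently

try:
    check_branches(3, [4, 4], [48, 96, 192], [48, 96, 192])
except ValueError as err:
    print(err)  # NUM_BRANCHES(3) <> NUM_BLOCKS(2)
```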
3. _make_one_branch function
It builds one branch, that is, a sequence of num_blocks[branch_index] residual blocks working at a single resolution, as shown in the figure.
def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                     stride=1):
    downsample = None
    if stride != 1 or \
       self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
        downsample = nn.Sequential(
            nn.Conv2d(self.num_inchannels[branch_index],
                      num_channels[branch_index] * block.expansion,
                      kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(num_channels[branch_index] * block.expansion,
                           momentum=BN_MOMENTUM),
        )

    layers = []
    layers.append(block(self.num_inchannels[branch_index],
                        num_channels[branch_index], stride, downsample))
    self.num_inchannels[branch_index] = \
        num_channels[branch_index] * block.expansion
    for i in range(1, num_blocks[branch_index]):
        layers.append(block(self.num_inchannels[branch_index],
                            num_channels[branch_index]))

    return nn.Sequential(*layers)
The _make_branches function looks at how many branches the module has and calls _make_one_branch once for each of them.
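The channel bookkeeping behind these two functions can be sketched without PyTorch. The sketch below only tracks channel counts instead of building layers, so `plan_one_branch` and `plan_branches` are hypothetical stand-ins, and the widths and block counts are example values. It mirrors one detail of the code above: the stored input width of a branch is updated in place after the first block.

```python
def plan_one_branch(branch_index, expansion, num_blocks, num_inchannels, num_channels):
    """Return [(in_channels, out_channels, has_downsample), ...] for one branch."""
    out_channels = num_channels[branch_index] * expansion
    in_channels = num_inchannels[branch_index]
    # Only the first block may need a projection on the shortcut (stride is 1 here).
    plan = [(in_channels, out_channels, in_channels != out_channels)]
    # As in the original, the stored input width is updated in place.
    num_inchannels[branch_index] = out_channels
    for _ in range(1, num_blocks[branch_index]):
        plan.append((out_channels, out_channels, False))
    return plan


def plan_branches(num_branches, expansion, num_blocks, num_inchannels, num_channels):
    """One plan per branch: plan_one_branch is called once per branch."""
    return [plan_one_branch(i, expansion, num_blocks, num_inchannels, num_channels)
            for i in range(num_branches)]


inchannels = [48, 96]
plans = plan_branches(2, 1, [4, 4], inchannels, [48, 96])
print(len(plans), [len(p) for p in plans])  # 2 [4, 4]
print(plans[0][0])                          # (48, 48, False)
```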
4. forward
def forward(self, x):
    if self.num_branches == 1:
        return [self.branches[0](x[0])]

    for i in range(self.num_branches):
        x[i] = self.branches[i](x[i])

    x_fuse = []
    for i in range(len(self.fuse_layers)):
        y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
        for j in range(1, self.num_branches):
            if i == j:
                y = y + x[j]
            else:
                y = y + self.fuse_layers[i][j](x[j])
        x_fuse.append(self.relu(y))

    return x_fuse
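Because forward sums the transformed outputs of all branches, every term added into output i must have the shape of branch i. The sketch below checks this on plain shape tuples. It assumes the fusion layers upsample by 2**(j - i) or apply (i - j) stride-2 3x3 convolutions with padding 1, as in the fuse-layer construction; the function names and example shapes are illustrative.

```python
def fused_term_shape(src_shape, src_index, dst_index, dst_channels):
    """Shape of branch `src_index` after it is resized for target `dst_index`."""
    _, h, w = src_shape
    if src_index > dst_index:
        # Lower-resolution source: upsample by 2**(j - i).
        factor = 2 ** (src_index - dst_index)
        h, w = h * factor, w * factor
    else:
        # Higher-resolution source: (i - j) stride-2 3x3 convolutions, padding 1.
        for _ in range(dst_index - src_index):
            h, w = (h - 1) // 2 + 1, (w - 1) // 2 + 1
    return (dst_channels, h, w)


def fuse_shapes(shapes):
    """Output shapes of the fusion step; checks that all summed terms agree."""
    outputs = []
    for i, (c_i, h_i, w_i) in enumerate(shapes):
        for j, shape_j in enumerate(shapes):
            term = fused_term_shape(shape_j, j, i, c_i)
            assert term == (c_i, h_i, w_i), (i, j, term)
        outputs.append((c_i, h_i, w_i))
    return outputs


shapes = [(48, 128, 256), (96, 64, 128), (192, 32, 64)]
print(fuse_shapes(shapes) == shapes)  # True
```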
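5. Constructing the multi-scale feature fusion layers: _make_fuse_layers function
In the code, i in the double loop is the index of the target branch, the one the other feature maps are fused into. As shown in the figure above, when i=0 all feature maps are fused into the feature map of branch 0. j is the index of the source branch whose feature map is being fused. When j > i the source has a lower resolution and is upsampled by 2**(j-i) after a 1x1 convolution; when j < i it is downsampled by i-j stride-2 3x3 convolutions; when j == i no layer is needed.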
def _make_fuse_layers(self):
    if self.num_branches == 1:
        return None

    num_branches = self.num_branches
    num_inchannels = self.num_inchannels
    fuse_layers = []
    for i in range(num_branches if self.multi_scale_output else 1):
        fuse_layer = []
        for j in range(num_branches):
            if j > i:
                fuse_layer.append(nn.Sequential(
                    nn.Conv2d(num_inchannels[j], num_inchannels[i],
                              1, 1, 0, bias=False),
                    nn.BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM),
                    nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
            elif j == i:
                fuse_layer.append(None)
            else:
                conv3x3s = []
                for k in range(i-j):
                    if k == i - j - 1:
                        num_outchannels_conv3x3 = num_inchannels[i]
                        conv3x3s.append(nn.Sequential(
                            nn.Conv2d(num_inchannels[j],
                                      num_outchannels_conv3x3,
                                      3, 2, 1, bias=False),
                            nn.BatchNorm2d(num_outchannels_conv3x3,
                                           momentum=BN_MOMENTUM)))
                    else:
                        num_outchannels_conv3x3 = num_inchannels[j]
                        conv3x3s.append(nn.Sequential(
                            nn.Conv2d(num_inchannels[j],
                                      num_outchannels_conv3x3,
                                      3, 2, 1, bias=False),
                            nn.BatchNorm2d(num_outchannels_conv3x3,
                                           momentum=BN_MOMENTUM),
                            nn.ReLU(False)))
                fuse_layer.append(nn.Sequential(*conv3x3s))
        fuse_layers.append(nn.ModuleList(fuse_layer))

    return nn.ModuleList(fuse_layers)
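To see at a glance which layer ends up at position [i][j], the sketch below rebuilds the same double loop but emits short text descriptions instead of PyTorch layers. It assumes multi_scale_output is True; `fuse_plan` and the widths 48, 96, 192 are illustrative.

```python
def fuse_plan(num_inchannels):
    """plan[i][j]: how branch j is transformed before being added to branch i."""
    n = len(num_inchannels)
    plan = []
    for i in range(n):
        row = []
        for j in range(n):
            if j > i:
                # Lower-resolution source: match channels, then upsample.
                row.append("conv1x1 {}->{}, upsample x{}".format(
                    num_inchannels[j], num_inchannels[i], 2 ** (j - i)))
            elif j == i:
                row.append("identity")
            else:
                # Higher-resolution source: chain of stride-2 3x3 convolutions;
                # only the last one changes the channel count.
                steps = []
                for k in range(i - j):
                    last = k == i - j - 1
                    out = num_inchannels[i] if last else num_inchannels[j]
                    steps.append("conv3x3/s2 {}->{}".format(num_inchannels[j], out))
                row.append(", ".join(steps))
        plan.append(row)
    return plan


for row in fuse_plan([48, 96, 192]):
    print(row)
```

The last row shows the two-step path from branch 0 to branch 2: the first stride-2 convolution keeps 48 channels and only the second one maps to 192.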
6. _make_transition_layer function (the branch with the cross in the figure above)
The transition layer performs the two transformations required to connect consecutive stages:
(1) input channel conversion
(2) feature size downsampling
def _make_transition_layer(
        self, num_channels_pre_layer, num_channels_cur_layer):
    num_branches_cur = len(num_channels_cur_layer)
    num_branches_pre = len(num_channels_pre_layer)

    transition_layers = []
    for i in range(num_branches_cur):
        if i < num_branches_pre:
            if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                transition_layers.append(nn.Sequential(
                    nn.Conv2d(num_channels_pre_layer[i],
                              num_channels_cur_layer[i],
                              3, 1, 1, bias=False),
                    nn.BatchNorm2d(
                        num_channels_cur_layer[i], momentum=BN_MOMENTUM),
                    nn.ReLU(inplace=True)))
            else:
                transition_layers.append(None)
        else:
            conv3x3s = []
            for j in range(i+1-num_branches_pre):
                inchannels = num_channels_pre_layer[-1]
                outchannels = num_channels_cur_layer[i] \
                    if j == i-num_branches_pre else inchannels
                conv3x3s.append(nn.Sequential(
                    nn.Conv2d(
                        inchannels, outchannels, 3, 2, 1, bias=False),
                    nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
                    nn.ReLU(inplace=True)))
            transition_layers.append(nn.Sequential(*conv3x3s))

    return nn.ModuleList(transition_layers)
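The branching logic of the transition can be followed more easily on a small example. The sketch below repeats the same loop but returns text descriptions instead of layers. `transition_plan` is a hypothetical helper, and the channel lists ([256] into [48, 96], then [48, 96] into [48, 96, 192]) are example values.

```python
def transition_plan(pre_channels, cur_channels):
    """One entry per branch of the next stage, following the rules above."""
    num_pre = len(pre_channels)
    plan = []
    for i, cur in enumerate(cur_channels):
        if i < num_pre:
            if cur != pre_channels[i]:
                # Existing branch, different width: channel conversion only.
                plan.append("conv3x3/s1 {}->{}".format(pre_channels[i], cur))
            else:
                plan.append(None)  # existing branch passes through unchanged
        else:
            # New branch: derived from the last (lowest-resolution) existing
            # branch by stride-2 convolutions; only the last one changes width.
            steps = []
            inchannels = pre_channels[-1]
            for j in range(i + 1 - num_pre):
                out = cur if j == i - num_pre else inchannels
                steps.append("conv3x3/s2 {}->{}".format(inchannels, out))
            plan.append(", ".join(steps))
    return plan


print(transition_plan([256], [48, 96]))
# ['conv3x3/s1 256->48', 'conv3x3/s2 256->96']
print(transition_plan([48, 96], [48, 96, 192]))
# [None, None, 'conv3x3/s2 96->192']
```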
7. Building the layer of stage 1: _make_layer()
Stage 1 produces a 1/4-resolution feature map and has no branch structure. Its layer is built with a _make_layer() function, the same way as in ResNet.
8. Building the layers of stages 2/3/4: _make_stage
Stages 2/3/4 are the core structure of HRNet. They are built from the core class HighResolutionModule, which constructs the branches with _make_branches and the feature fusion layers with _make_fuse_layers.