Implementation process of the classifier in SO-Net

1. Data preparation

  This part of the code is in ./data/modelnet_shrec_. The loader reads four arrays:

- pc_np (point coordinates)
- surface_normal_np (normal vectors)
- som_node_np (SOM node coordinates)
- class_id (category)

The data are then augmented with rotation, perturbation (jitter), scaling and translation. The loader returns the point coordinates, the normal vectors, the category, the SOM nodes, and, for each SOM node, the indices of its k nearest neighbors among the SOM nodes.
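The augmentation step can be pictured with the NumPy sketch below. It is only a sketch, not the loader's code. `augment` is a hypothetical helper, and several values are assumptions: the rotation axis (y), the jitter sigma and clip, the scale range and the shift range. Normals are only rotated, because uniform scaling and translation do not change a direction.

```python
import numpy as np

def augment(pc, sn, node, rng):
    """Hypothetical augmentation: pc and sn are Nx3, node is Mx3."""
    # 1. rotation about the y axis (assumed up axis), applied to points, normals and nodes
    theta = rng.uniform(0.0, 2.0 * np.pi)
    c, s = np.cos(theta), np.sin(theta)
    rot = np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])
    pc, sn, node = pc @ rot.T, sn @ rot.T, node @ rot.T
    # 2. perturbation: small clipped Gaussian jitter on the point coordinates only
    pc = pc + np.clip(0.01 * rng.standard_normal(pc.shape), -0.05, 0.05)
    # 3. uniform scaling of points and nodes (normals keep their direction)
    scale = rng.uniform(0.8, 1.2)
    pc, node = pc * scale, node * scale
    # 4. displacement: one random shift shared by points and nodes
    shift = rng.uniform(-0.1, 0.1, size=(1, 3))
    pc, node = pc + shift, node + shift
    return pc, sn, node

rng = np.random.default_rng(0)
pc = rng.standard_normal((1024, 3))
sn = pc / np.linalg.norm(pc, axis=1, keepdims=True)  # toy unit normals
node = rng.standard_normal((64, 3))
pc2, sn2, node2 = augment(pc, sn, node, rng)
```

Applying the same rotation, scale and shift to points and nodes keeps the SOM nodes aligned with the augmented cloud.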

2. Model

  This part of the code is in ./models/. It consists of two parts: the encoder network and the classification network.

2.1 Encoder network

   The encoder is the part with the blue background in the figure above, and it is defined in ./models/. The forward() method of the Encoder class takes four inputs:

- the point coordinates
- the normal vectors
- the node coordinates
- the node-to-node kNN index

2.1.1 SOM layer

   As shown in the figure above, the first part is the SOM layer. The point-to-node mapping is implemented by the query_topk() function of the BatchSOM class in ./util/som.py. It returns mask, mask_row_max and min_idx.
(1) mask: shape [B, kN, M]. The i-th block of N rows (an N*M slice) marks, for each of the N points, which node is its i-th nearest node.

In the figure above, the red 1 indicates that the second node is the first nearest node of the first point. The blue 1 indicates that the third node is the first nearest node of the fourth point. The bottom 1 indicates that the fifth node is the second nearest node of the third point.
(2) mask_row_max: shape [B, M]. It indicates, for each node, whether it is a nearest node of any point.

A 1 in column i means that node i is one of the k nearest nodes of at least one point.
(3) min_idx: shape [B, kN]. The i-th block of N entries holds, for each of the N points, the index of its i-th nearest node.

It carries the same information as mask in index form. Each row of mask is the one-hot encoding of the corresponding entry of min_idx.
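The relationship between the three outputs can be checked with a small NumPy re-implementation. It is a sketch under stated assumptions, not BatchSOM.query_topk itself:

- point-to-node distances are assumed to be squared Euclidean;
- the layout follows the description above, so entry i*N + n belongs to point n and its i-th nearest node;
- `query_topk_sketch` is a hypothetical name.

```python
import numpy as np

def query_topk_sketch(x, node, k):
    """x: BxCxN points, node: BxCxM nodes -> mask BxkNxM, mask_row_max BxM, min_idx BxkN."""
    B, _, N = x.shape
    M = node.shape[2]
    # squared distance between every point and every node: BxNxM
    dist = ((x[:, :, :, None] - node[:, :, None, :]) ** 2).sum(axis=1)
    # indices of the k nearest nodes of each point, nearest first: BxNxk
    order = np.argsort(dist, axis=2)[:, :, :k]
    # reorder to BxkxN and flatten, so block i holds the i-th nearest node of all N points
    min_idx = order.transpose(0, 2, 1).reshape(B, k * N)
    # mask is the one-hot encoding of min_idx along the node axis
    mask = np.zeros((B, k * N, M), dtype=np.int64)
    np.put_along_axis(mask, min_idx[:, :, None], 1, axis=2)
    # 1 where a node is among the k nearest nodes of at least one point
    mask_row_max = mask.max(axis=1)
    return mask, mask_row_max, min_idx

x = np.array([[[0.0, 1.0, 5.0, 6.0]]])   # B=1, C=1, N=4 points on a line
node = np.array([[[0.5, 5.5, 100.0]]])   # M=3 nodes; the last one is far away
mask, mask_row_max, min_idx = query_topk_sketch(x, node, k=2)
```

In this toy case min_idx is [[0, 0, 1, 1, 1, 1, 0, 0]]. The first four entries are each point's nearest node and the last four its second nearest. The far node never appears, so its mask_row_max entry is 0.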
  Next, each node is moved to the center (mean) of all points that take it as one of their nearest nodes:

        self.mask, mask_row_max, min_idx = self.som_builder.query_topk(x, k=self.opt.k)  # BxkNxM, BxM, BxkN
        mask_row_sum = torch.sum(self.mask, dim=1)  # BxM, number of point copies assigned to each node
        mask = self.mask.unsqueeze(1)  # Bx1xkNxM

        # Stack k copies of x and sn along the point dimension
        x_list, sn_list = [], []
        for i in range(self.opt.k):
            x_list.append(x)
            sn_list.append(sn)
        x_stack = torch.cat(x_list, dim=2)  # BxCxkN
        sn_stack = torch.cat(sn_list, dim=2)  # BxCxkN

        # New node coordinates: mean of all points that take the node as one of their k nearest nodes
        x_stack_data_unsqueeze = x_stack.unsqueeze(3)  # BxCxkNx1
        x_stack_data_masked = x_stack_data_unsqueeze * mask.float()  # BxCxkNxM
        cluster_mean = torch.sum(x_stack_data_masked, dim=2) / (mask_row_sum.unsqueeze(1).float()+1e-5)  # BxCxM; +1e-5 avoids division by zero when no point is assigned to a node
        self.som_builder.node = cluster_mean
        self.som_node = self.som_builder.node
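The mean update is easier to follow on a tiny example. The NumPy sketch below mirrors the tensor operations above: tile x k times, mask, sum, then divide by the count plus 1e-5. `recompute_nodes` is a hypothetical name, and the numbers are toy values.

```python
import numpy as np

def recompute_nodes(x, mask, k, eps=1e-5):
    """x: BxCxN, mask: BxkNxM one-hot rows -> new node coordinates BxCxM."""
    x_stack = np.concatenate([x] * k, axis=2)               # BxCxkN
    masked = x_stack[:, :, :, None] * mask[:, None, :, :]    # BxCxkNxM
    count = mask.sum(axis=1)                                  # BxM
    return masked.sum(axis=2) / (count[:, None, :] + eps)     # BxCxM

x = np.array([[[1.0, 3.0, 10.0]]])  # B=1, C=1, N=3
# k=1: points 0 and 1 belong to node 0, point 2 to node 1, node 2 is empty
mask = np.array([[[1, 0, 0], [1, 0, 0], [0, 1, 0]]])
nodes = recompute_nodes(x, mask, k=1)
```

Node 0 moves to about 2.0 (the mean of 1 and 3) and node 1 to about 10.0. The empty node ends up at 0 instead of triggering a division by zero, which is what the 1e-5 term is for.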

Then each point is decentered by subtracting the node it is assigned to. The result is concatenated with sn to form the input of the next stage:

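        node_expanded = self.som_node.unsqueeze(2)  # BxCx1xM
        self.centers = torch.sum(mask.float() * node_expanded, dim=3).detach()  # BxCxkN, coordinates of the node assigned to each point copy

        self.x_decentered = (x_stack - self.centers).detach()  # Bx3xkN
        x_augmented = torch.cat((self.x_decentered, sn_stack), dim=1)  # Bx6xkN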
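In the block above, the masked sum over the node axis picks out one node per point, because each mask row is one-hot. The NumPy sketch below checks this on two points with hand-picked coordinates. `decenter_and_concat` is a hypothetical name.

```python
import numpy as np

def decenter_and_concat(x_stack, sn_stack, node, mask):
    """x_stack, sn_stack: Bx3xkN, node: Bx3xM, mask: BxkNxM (one-hot rows) -> Bx6xkN."""
    # for every point copy, select the coordinates of its assigned node
    centers = (mask[:, None, :, :] * node[:, :, None, :]).sum(axis=3)  # Bx3xkN
    x_decentered = x_stack - centers
    return np.concatenate([x_decentered, sn_stack], axis=1)

node = np.array([[[1.0, 5.0], [0.0, 0.0], [0.0, 0.0]]])     # B=1, M=2 nodes at x=1 and x=5
x_stack = np.array([[[1.5, 4.0], [0.0, 0.0], [0.0, 0.0]]])  # kN=2 points
sn_stack = np.zeros((1, 3, 2))
mask = np.array([[[1, 0], [0, 1]]])                          # point 0 -> node 0, point 1 -> node 1
out = decenter_and_concat(x_stack, sn_stack, node, mask)
```

The first channel of out becomes [0.5, -1.0], each point's offset from its own node, followed by the three normal channels.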

2.1.2 first_pointnet

  This part of the code is in ./models/. It is essentially a residual network in which every layer is an equivariant layer, also defined in ./models/.
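"Equivariant" here means that the same weights are applied to every point, like a 1x1 convolution. Permuting the input points therefore just permutes the outputs. The NumPy sketch below illustrates that property with a generic residual block. Batch normalization is omitted, the weights are random, and the shortcut wiring is an assumption that may differ from SO-Net's actual network.

```python
import numpy as np

def shared_linear(x, w, b):
    # the same (w, b) is applied to every point: a 1x1 convolution over the point axis
    return np.einsum('oc,bcn->bon', w, x) + b[None, :, None]

def equivariant_layer(x, w, b):
    return np.maximum(shared_linear(x, w, b), 0.0)  # shared linear map + ReLU

def residual_block(x, w1, b1, w2, b2):
    # two shared layers with an identity shortcut (needs equal input/output channels)
    h = equivariant_layer(x, w1, b1)
    return np.maximum(x + shared_linear(h, w2, b2), 0.0)

rng = np.random.default_rng(0)
C = 8
w1, w2 = rng.standard_normal((C, C)), rng.standard_normal((C, C))
b1, b2 = rng.standard_normal(C), rng.standard_normal(C)
x = rng.standard_normal((2, C, 16))
perm = rng.permutation(16)
y = residual_block(x, w1, b1, w2, b2)
y_perm = residual_block(x[:, :, perm], w1, b1, w2, b2)
```

Shuffling the points before the block gives the same result as shuffling the block's output, which is the equivariance property.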

(index_max is a C++/CUDA extension. I don't understand what it does yet; I'm still too much of a beginner, sob.)
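For reference, the SO-Net paper builds each node feature by channel-wise max pooling over the features of the points associated with that node. index_max presumably performs this on the GPU using min_idx. The NumPy sketch below is an assumed loop-based equivalent, not the extension itself. `index_max_sketch` is a hypothetical name, and the choice to leave empty nodes at 0 is an assumption.

```python
import numpy as np

def index_max_sketch(feat, min_idx, M):
    """feat: BxCxkN point features, min_idx: BxkN node index per point copy -> BxCxM node features."""
    B, C, _ = feat.shape
    out = np.full((B, C, M), -np.inf)
    for b in range(B):
        # unbuffered max: repeated node indices are all taken into account
        np.maximum.at(out[b].T, min_idx[b], feat[b].T)
    out[np.isinf(out)] = 0.0  # nodes that received no point (assumed fill value)
    return out

feat = np.array([[[1.0, 4.0, 2.0], [0.0, -1.0, 3.0]]])  # B=1, C=2, kN=3
min_idx = np.array([[0, 0, 1]])                          # points 0 and 1 -> node 0, point 2 -> node 1
pooled = index_max_sketch(feat, min_idx, M=3)
```

Node 0 gets the channel-wise maximum of points 0 and 1, which is [4, 0]. Node 1 gets point 2's features, [2, 3].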

2.1.3 knnlayer

  This part of the code is in ./models/. First, compute the indices of the K nearest centers of each center:

        coordinate_tensor =  # Bx3xM is the center of all points with the node as the nearest neighbor
        if precomputed_knn_I is not None:
            # reuse the precomputed node-to-node kNN index
            assert precomputed_knn_I.size()[2] >= K
            knn_I = precomputed_knn_I[:, :, 0:K]
        else:
            coordinate_Mx1 = coordinate_tensor.unsqueeze(3)  # Bx3xMx1
            coordinate_1xM = coordinate_tensor.unsqueeze(2)  # Bx3x1xM
            norm = torch.sum((coordinate_Mx1 - coordinate_1xM) ** 2, dim=1)  # BxMxM, squared distance between every pair of centers
            knn_D, knn_I = torch.topk(norm, k=K, dim=2, largest=False, sorted=True)  # BxMxK, distances and indices of the K nearest centers
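The distance-and-top-K step can be reproduced in NumPy as a sketch. A stable argsort plays the role of `torch.topk(..., largest=False, sorted=True)`; the two agree when there are no ties. `node_knn` is a hypothetical name. One consequence of the code is worth noting: the diagonal of norm is 0, so each center is its own first neighbor.

```python
import numpy as np

def node_knn(coordinate, K):
    """coordinate: Bx3xM -> (knn_D, knn_I), both BxMxK, nearest first."""
    diff = coordinate[:, :, :, None] - coordinate[:, :, None, :]  # Bx3xMxM
    norm = (diff ** 2).sum(axis=1)                                 # BxMxM squared distances
    knn_I = np.argsort(norm, axis=2, kind='stable')[:, :, :K]      # K smallest, ascending
    knn_D = np.take_along_axis(norm, knn_I, axis=2)
    return knn_D, knn_I

coord = np.array([[[0.0, 1.0, 3.0, 7.0], [0.0] * 4, [0.0] * 4]])  # four centers on the x axis
knn_D, knn_I = node_knn(coord, K=2)
```

Here knn_I[0] is [[0, 1], [1, 0], [2, 1], [3, 2]]: every row starts with the center itself at distance 0.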

Then, for each center, gather the coordinates of its K nearest centers. Compute their reference center (either their mean or the center itself), and decenter the neighbor coordinates against it:

        neighbors = operations.knn_gather_wrapper(coordinate_tensor, knn_I)  # Bx3xMxK, coordinates of the K nearest centers of each center
        if center_type == 'avg':  # use the mean of the K neighbors as the reference center
            neighbors_center = torch.mean(neighbors, dim=3, keepdim=True)  # Bx3xMx1
        elif center_type == 'center':  # use the center itself as the reference center
            neighbors_center = coordinate_tensor.unsqueeze(3)  # Bx3xMx1
        neighbors_decentered = (neighbors - neighbors_center).detach()  # Bx3xMxK, decentered neighbor coordinates
        neighbors_center = neighbors_center.squeeze(3).detach()  # Bx3xM, reference center coordinates
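The two center_type options differ only in the point subtracted from the gathered neighbors. The NumPy sketch below makes that concrete. `knn_gather_sketch` stands in for operations.knn_gather_wrapper under the assumption that it gathers columns by index, and both function names are hypothetical.

```python
import numpy as np

def knn_gather_sketch(coordinate, knn_I):
    """coordinate: BxCxM, knn_I: BxMxK -> BxCxMxK with out[b, :, m, j] = coordinate[b, :, knn_I[b, m, j]]."""
    return np.stack([c[:, idx] for c, idx in zip(coordinate, knn_I)])

def decenter_neighbors(coordinate, knn_I, center_type):
    neighbors = knn_gather_sketch(coordinate, knn_I)               # Bx3xMxK
    if center_type == 'avg':
        neighbors_center = neighbors.mean(axis=3, keepdims=True)   # Bx3xMx1, mean of the K neighbors
    elif center_type == 'center':
        neighbors_center = coordinate[:, :, :, None]               # Bx3xMx1, the center itself
    else:
        raise ValueError(center_type)
    return neighbors - neighbors_center, neighbors_center[:, :, :, 0]

coord = np.array([[[0.0, 1.0, 3.0, 7.0], [0.0] * 4, [0.0] * 4]])  # four centers on the x axis
knn_I = np.array([[[0, 1], [1, 0], [2, 1], [3, 2]]])               # K=2 nearest centers of each center
dec_c, ctr_c = decenter_neighbors(coord, knn_I, 'center')
dec_a, ctr_a = decenter_neighbors(coord, knn_I, 'avg')
```

Take center 2 (x=3), whose neighbors are at x=3 and x=1. With 'center' the offsets are [0, -2]. With 'avg' the reference point is 2, so the offsets become [1, -1].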

Finally, the feature vectors of the K nearest centers of each center are gathered and concatenated with the decentered coordinates. The result is the input of the convolution layers, which are defined in the same file. The layer returns the center coordinates and the feature vector of each center.

        x_neighbors = operations.knn_gather_by_indexing(x, knn_I)  # BxCxMxK, feature vectors of the K nearest centers of each center
        x_augmented = torch.cat((neighbors_decentered, x_neighbors), dim=1)  # Bx(3+C)xMxK, concatenated with the decentered coordinates
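The NumPy sketch below puts the gather, the concatenation and the convolution together. It assumes the convolution layers act as a shared MLP over every (center, neighbor) pair, followed by a max over the K neighbors, which yields one feature vector per center. The single layer, its random weights and the name `knn_layer_sketch` are illustrative assumptions.

```python
import numpy as np

def knn_layer_sketch(neighbors_decentered, x, knn_I, w, b):
    """neighbors_decentered: Bx3xMxK, x: BxCxM, knn_I: BxMxK, w: Coutx(3+C) -> BxCoutxM."""
    x_neighbors = np.stack([f[:, idx] for f, idx in zip(x, knn_I)])           # BxCxMxK
    x_augmented = np.concatenate([neighbors_decentered, x_neighbors], axis=1)  # Bx(3+C)xMxK
    # shared 1x1 "convolution" + ReLU applied to every (center, neighbor) pair
    h = np.maximum(np.einsum('oc,bcmk->bomk', w, x_augmented) + b[None, :, None, None], 0.0)
    return h.max(axis=3)  # max over the K neighbors -> one feature vector per center

rng = np.random.default_rng(0)
B, C, M, K, Cout = 2, 4, 6, 3, 5
x = rng.standard_normal((B, C, M))
knn_I = rng.integers(0, M, size=(B, M, K))
nd = rng.standard_normal((B, 3, M, K))
w, b = rng.standard_normal((Cout, 3 + C)), rng.standard_normal(Cout)
out = knn_layer_sketch(nd, x, knn_I, w, b)
perm = np.array([2, 0, 1])
out_perm = knn_layer_sketch(nd[..., perm], x, knn_I[..., perm], w, b)
```

Because of the max over K, the order in which the neighbors are listed does not affect the result.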

2.1.4 final_pointnet

  This part of the code is in ./models/. It is a conventional PointNet structure that produces the global feature vector ("global feature" in the figure).
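The core PointNet idea is a shared per-element MLP followed by a symmetric max pooling, which can be sketched in NumPy. Here the elements are the M node features. The layer sizes and random weights are hypothetical, and batch normalization is left out.

```python
import numpy as np

def final_pointnet_sketch(node_feat, layers):
    """node_feat: BxCxM per-node features -> BxC_out global feature."""
    h = node_feat
    for w, b in layers:
        # shared per-node linear map + ReLU (a 1x1 convolution)
        h = np.maximum(np.einsum('oc,bcm->bom', w, h) + b[None, :, None], 0.0)
    return h.max(axis=2)  # symmetric max pooling over the M nodes

rng = np.random.default_rng(0)
layers = [(rng.standard_normal((16, 8)), rng.standard_normal(16)),
          (rng.standard_normal((32, 16)), rng.standard_normal(32))]
node_feat = rng.standard_normal((2, 8, 10))
g = final_pointnet_sketch(node_feat, layers)
g_shuffled = final_pointnet_sketch(node_feat[:, :, rng.permutation(10)], layers)
```

The global feature does not depend on the order of the nodes, which is why max pooling is used here.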

2.2 Classifier

  This part of the code is in ./models/. It is essentially three fully connected layers that output a score for each category of the point cloud, which becomes a probability after softmax. It is trained with the cross-entropy loss.
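The three FC layers plus cross-entropy can be sketched in NumPy. The layer widths and the number of classes below are hypothetical, and dropout and batch normalization are omitted. The loss applies a numerically stable log-softmax to the scores before taking the negative log-likelihood of the true class.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def classifier_forward(g, params):
    """g: BxD global feature -> Bx(num_classes) scores (logits)."""
    (w1, b1), (w2, b2), (w3, b3) = params
    h = relu(g @ w1 + b1)
    h = relu(h @ w2 + b2)
    return h @ w3 + b3  # no activation on the last layer

def cross_entropy(logits, labels):
    z = logits - logits.max(axis=1, keepdims=True)  # subtract the row max before exp for stability
    log_prob = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_prob[np.arange(len(labels)), labels].mean()

rng = np.random.default_rng(0)
dims = [16, 8, 4, 3]  # hypothetical: 16-d global feature, two hidden layers, 3 classes
params = [(0.1 * rng.standard_normal((i, o)), np.zeros(o)) for i, o in zip(dims[:-1], dims[1:])]
logits = classifier_forward(rng.standard_normal((5, 16)), params)
loss = cross_entropy(logits, np.array([0, 1, 2, 0, 1]))
```

With all-zero scores over two classes, the loss equals log 2, a handy sanity check.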

3. Testing and saving

3.1 get_current_errors & visualizer.plot_current_errors

  This part of the code is in ./models/. It computes the prediction accuracy and visualizes training progress (loss over time).
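The accuracy statistic is the fraction of samples whose highest-scoring class matches the label. The NumPy helper below is a hypothetical stand-in for that part of get_current_errors, not its actual code.

```python
import numpy as np

def accuracy(logits, labels):
    """Fraction of rows whose argmax class equals the label."""
    pred = logits.argmax(axis=1)
    return float((pred == labels).mean())

logits = np.array([[2.0, 1.0], [0.0, 3.0], [1.0, 0.0]])
labels = np.array([0, 1, 1])
acc = accuracy(logits, labels)
```

The first two predictions are correct and the third is not, so acc is 2/3.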

3.2 model.save_network

  Saves the model to the specified path.

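    def save_network(self, network, network_label, epoch_label, gpu_id):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.opt.checkpoints_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)  # save CPU tensors so the checkpoint loads without a GPU
        if gpu_id>=0 and torch.cuda.is_available():
            # torch.cuda.device(gpu_id)
            network.cuda(gpu_id)  # move the network back to the GPU after saving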

3.3 model.update_learning_rate

  Updates the learning rate.

    def update_learning_rate(self, ratio):
        lr_clip = 0.00001

        # encoder
        lr_encoder = self.old_lr_encoder * ratio
        if lr_encoder < lr_clip:
            lr_encoder = lr_clip
        for param_group in self.optimizer_encoder.param_groups:
            param_group['lr'] = lr_encoder
        print('update encoder learning rate: %f -> %f' % (self.old_lr_encoder, lr_encoder))
        self.old_lr_encoder = lr_encoder

        # classifier
        lr_classifier = self.old_lr_classifier * ratio
        if lr_classifier < lr_clip:
            lr_classifier = lr_clip
        for param_group in self.optimizer_classifier.param_groups:
            param_group['lr'] = lr_classifier
        print('update classifier learning rate: %f -> %f' % (self.old_lr_classifier, lr_classifier))
        self.old_lr_classifier = lr_classifier
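update_learning_rate applies the same clipped multiplicative decay to both optimizers. The rule itself fits in a short pure-Python sketch. `decayed_lr` and `lr_schedule` are hypothetical helpers, and the base rate 0.001 and ratio 0.5 in the example are illustrative values, not the project's settings.

```python
def decayed_lr(old_lr, ratio, lr_clip=0.00001):
    """Multiply the learning rate by ratio, but never go below lr_clip."""
    return max(old_lr * ratio, lr_clip)

def lr_schedule(base_lr, ratio, steps, lr_clip=0.00001):
    """Learning rates after 0..steps successive updates."""
    lrs = [base_lr]
    for _ in range(steps):
        lrs.append(decayed_lr(lrs[-1], ratio, lr_clip))
    return lrs

lrs = lr_schedule(0.001, 0.5, 8)
```

Halving 0.001 seven times would give about 7.8e-6. The clip therefore holds the rate at 1e-5 from the seventh update onward.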


Posted on Wed, 03 Nov 2021 19:20:48 -0400 by cybercog