1. Data preparation
This part of the code is located in ./data/modelnet_shrec_loader.py. The data read are pc_np (point coordinates), surface_normal_np (surface normals), som_node_np (SOM node coordinates) and class_id (category). The data are then augmented with rotation, perturbation (jitter), scaling and translation. The loader returns the point coordinates, the normals, the category, the SOM nodes, and, for each SOM node, the indices of its k nearest neighbors among the SOM nodes.
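The four augmentation steps can be sketched in NumPy as below. This is only an illustration, not the loader's actual code: the function name augment, the choice of the y axis as the rotation axis, and all numeric ranges are assumptions made for the example.

```python
import numpy as np

def augment(pc, normals, nodes, rng, sigma=0.01, clip=0.05,
            scale_range=(0.8, 1.25), shift_range=0.1):
    """Augment an N x 3 cloud, its N x 3 normals and M x 3 SOM nodes.

    All default values are illustrative assumptions.
    """
    # Rotation by a random angle about the y axis (assumed to be "up").
    theta = rng.uniform(0.0, 2.0 * np.pi)
    c, s = np.cos(theta), np.sin(theta)
    rot = np.array([[c, 0.0, s],
                    [0.0, 1.0, 0.0],
                    [-s, 0.0, c]])
    pc = pc @ rot.T
    normals = normals @ rot.T   # normals rotate together with the points
    nodes = nodes @ rot.T

    # Perturbation: small clipped Gaussian noise on the points only.
    noise = np.clip(sigma * rng.standard_normal(pc.shape), -clip, clip)
    pc = pc + noise

    # Isotropic scaling: positions change, unit normals do not.
    scale = rng.uniform(*scale_range)
    pc = pc * scale
    nodes = nodes * scale

    # Translation: one random offset shared by points and nodes.
    shift = rng.uniform(-shift_range, shift_range, size=(1, 3))
    return pc + shift, normals, nodes + shift
```

Applying the same rotation, scale and shift to the SOM nodes keeps them aligned with the points they summarize.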
2. Model
This part of the code is located in ./models/classifier.py. It consists of two parts: the encoder network and the classification network.
2.1 Encoder network
The Encoder is the part shown on the blue background in the figure above. It is defined in ./models/networks.py. The forward method of the Encoder class takes as input the point coordinates, the normals, the node coordinates, and the node-to-node kNN indices.
2.1.1 SOM layer
As the figure above shows, the first part is the SOM layer. The query_topk() function in the BatchSOM class of ./util/som.py implements the point-to-node mapping and returns mask, mask_row_max and min_idx.
(1) mask: its size is [B,kN,M]. The i-th N*M block marks, for each of the N points, which node is its i-th nearest neighbor.
In the figure above, the red 1 indicates that the second node is the first nearest neighbor of the first point. The blue 1 indicates that the third node is the first nearest neighbor of the fourth point, and the bottom 1 indicates that the fifth node is the second nearest neighbor of the third point.
(2) mask_row_max: its size is [B,M]. It indicates whether each node is a nearest neighbor of at least one point.
A 1 in position i means that the i-th node is a nearest neighbor of some point.
(3) min_idx: its size is [B,kN]. The i-th N*1 block holds, for each of the N points, the index of its i-th nearest neighbor node.
It corresponds to mask: where mask stores a 1, min_idx stores the index of that node.
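To make the three return values concrete, here is a minimal NumPy sketch for a single sample (no batch dimension). It is not the repository's BatchSOM.query_topk(); the name query_topk_sketch and the toy data are assumptions for illustration, and only the layout of the outputs follows the description above.

```python
import numpy as np

def query_topk_sketch(points, nodes, k):
    """points: N x 3, nodes: M x 3.

    Returns mask (kN x M), mask_row_max (M,), min_idx (kN,).
    """
    n, m = points.shape[0], nodes.shape[0]
    # Squared distance from every point to every node: N x M.
    dist = ((points[:, None, :] - nodes[None, :, :]) ** 2).sum(axis=2)
    # Indices of the k closest nodes of each point, nearest first: N x k.
    topk = np.argsort(dist, axis=1)[:, :k]
    # Stack rank by rank: the first N entries are the 1st neighbors,
    # the next N entries are the 2nd neighbors, and so on.
    min_idx = topk.T.reshape(-1)
    # One-hot encode min_idx to obtain the mask.
    mask = np.zeros((k * n, m), dtype=np.int64)
    mask[np.arange(k * n), min_idx] = 1
    # A node gets 1 if at least one row of the mask selects it.
    mask_row_max = mask.max(axis=0)
    return mask, mask_row_max, min_idx

# Four points on the x axis, three nodes, the last node far away.
points = np.array([[0.0, 0, 0], [1.0, 0, 0], [4.0, 0, 0], [5.0, 0, 0]])
nodes = np.array([[0.4, 0, 0], [4.6, 0, 0], [100.0, 0, 0]])
mask, mask_row_max, min_idx = query_topk_sketch(points, nodes, k=2)
print(min_idx)       # [0 0 1 1 1 1 0 0]
print(mask_row_max)  # [1 1 0]
```

The third node is never among the two nearest nodes of any point, so its entry in mask_row_max is 0.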
Then, for each node, compute the mean of all points that have this node among their k nearest neighbors, and use it as the new node coordinate.
    self.mask, mask_row_max, min_idx = self.som_builder.query_topk(x.data, k=self.opt.k)  # BxkNxnode_num, Bxnode_num
    mask_row_sum = torch.sum(self.mask, dim=1)  # BxM
    mask = self.mask.unsqueeze(1)  # Bx1xkNxM

    # Stack x and sn k times along the point dimension
    x_list, sn_list = [], []
    for i in range(self.opt.k):
        x_list.append(x)
        sn_list.append(sn)
    x_stack = torch.cat(tuple(x_list), dim=2)  # B x C x kN
    sn_stack = torch.cat(tuple(sn_list), dim=2)  # B x C x kN

    # Average the coordinates of all points assigned to each node and use the result as the new node coordinates
    x_stack_data_unsqueeze = x_stack.data.unsqueeze(3)  # BxCxkNx1
    x_stack_data_masked = x_stack_data_unsqueeze * mask.float()  # BxCxkNxM
    # The 1e-5 term avoids division by zero when no point has this node as a nearest neighbor
    cluster_mean = torch.sum(x_stack_data_masked, dim=2) / (mask_row_sum.unsqueeze(1).float() + 1e-5)  # BxCxM
    self.som_builder.node = cluster_mean
    self.som_node = self.som_builder.node
Then each point is centered by subtracting the coordinates of its assigned node, and the result is concatenated with the normals sn to form the input.
    node_expanded = self.som_node.data.unsqueeze(2)  # BxCx1xM
    self.centers = torch.sum(mask.float() * node_expanded, dim=3).detach()  # BxCxkN
    self.x_decentered = (x_stack - self.centers).detach()  # Bx3xkN
    x_augmented = torch.cat((self.x_decentered, sn_stack), dim=1)  # Bx6xkN
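The masked averaging and the centering step can be checked on a toy case. The sketch below redoes both in NumPy for one sample with k = 1; the data and the hand-built mask are made up for illustration, and only the broadcasting pattern mirrors the code above.

```python
import numpy as np

# Toy case: N = 4 points, M = 3 nodes, k = 1, so kN = 4.
x = np.array([[0.0, 1.0, 4.0, 5.0],
              [0.0, 0.0, 0.0, 0.0],
              [0.0, 0.0, 0.0, 0.0]])      # C x kN with C = 3
min_idx = np.array([0, 0, 1, 1])          # node assigned to each column of x
mask = np.zeros((4, 3))
mask[np.arange(4), min_idx] = 1.0         # kN x M, one-hot rows

# Number of points assigned to each node: M.
mask_row_sum = mask.sum(axis=0)

# Sum the points assigned to each node, then divide by the count.
# The small constant keeps the unused third node from dividing by zero.
x_masked = x[:, :, None] * mask[None, :, :]                  # C x kN x M
cluster_mean = x_masked.sum(axis=1) / (mask_row_sum + 1e-5)  # C x M

# Select the node of each point through the mask and subtract it.
centers = (mask[None, :, :] * cluster_mean[:, None, :]).sum(axis=2)  # C x kN
x_decentered = x - centers

print(np.round(cluster_mean[0], 3))
print(np.round(x_decentered[0], 3))
```

The first two points average to about 0.5 and the last two to about 4.5, so every point ends up roughly 0.5 away from its own center. The unused node gets a mean of exactly 0.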
2.1.2 first_pointnet
This part of the code is in ./models/layers.py. It is essentially a residual network in which each layer is an equivariant layer, also defined in ./models/layers.py.
(index_max is a C++/CUDA extension. I have not yet worked out what it does.)
2.1.3 knnlayer
This part of the code is in ./models/layers.py. First, compute, for each center, the indices of its K nearest centers:
    coordinate_tensor = coordinate.data  # Bx3xM, the center of all points assigned to each node
    if precomputed_knn_I is not None:
        assert precomputed_knn_I.size()[2] >= K
        knn_I = precomputed_knn_I[:, :, 0:K]
    else:
        coordinate_Mx1 = coordinate_tensor.unsqueeze(3)  # Bx3xMx1
        coordinate_1xM = coordinate_tensor.unsqueeze(2)  # Bx3x1xM
        norm = torch.sum((coordinate_Mx1 - coordinate_1xM) ** 2, dim=1)  # BxMxM, row i holds the squared distances from center i to every center
        knn_D, knn_I = torch.topk(norm, k=K, dim=2, largest=False, sorted=True)  # BxMxK, distances and indices of the K closest centers of each center
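The else branch builds the full pairwise distance matrix by broadcasting and then keeps the K smallest entries of each row. Below is a minimal NumPy sketch of the same idea for one sample; the name knn_indices is hypothetical, and a stable argsort stands in for torch.topk.

```python
import numpy as np

def knn_indices(coords, k):
    """coords: 3 x M. Returns an M x k array of neighbor indices, nearest first."""
    # Broadcast 3 x M x 1 against 3 x 1 x M to get all pairwise differences.
    diff = coords[:, :, None] - coords[:, None, :]   # 3 x M x M
    sq_dist = (diff ** 2).sum(axis=0)                # M x M squared distances
    # Sort each row in ascending order and keep the first k columns.
    return np.argsort(sq_dist, axis=1, kind="stable")[:, :k]

# Four centers on the x axis.
coords = np.array([[0.0, 1.0, 3.0, 7.0],
                   [0.0, 0.0, 0.0, 0.0],
                   [0.0, 0.0, 0.0, 0.0]])
knn = knn_indices(coords, k=2)
print(knn)
```

Because the diagonal of the distance matrix is zero, the first neighbor of every center is the center itself. Here the result is [[0, 1], [1, 0], [2, 1], [3, 2]].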
Then, for each center, gather the coordinates of its K nearest centers, compute their reference center (either their mean or the center itself), and subtract it to obtain centered coordinates:
    neighbors = operations.knn_gather_wrapper(coordinate_tensor, knn_I)  # Bx3xMxK, coordinates of the K nearest centers of each center
    if center_type == 'avg':  # use the mean of the K neighbors as the center
        neighbors_center = torch.mean(neighbors, dim=3, keepdim=True)  # Bx3xMx1
    elif center_type == 'center':  # use the center itself
        neighbors_center = coordinate_tensor.unsqueeze(3)  # Bx3xMx1
    neighbors_decentered = (neighbors - neighbors_center).detach()  # Bx3xMxK, neighbor coordinates relative to the chosen center
    neighbors_center = neighbors_center.squeeze(3).detach()  # Bx3xM, center coordinates
Finally, the feature vectors of the K centers nearest to each center are gathered, concatenated with the centered coordinates, and used as the input of the convolution layers, which are defined in the same file. The layer returns the coordinates and feature vectors of the center points.
    x_neighbors = operations.knn_gather_by_indexing(x, knn_I)  # BxCxMxK, feature vectors of the K nearest centers of each center
    x_augmented = torch.cat((neighbors_decentered, x_neighbors), dim=1)  # Bx(3+C)xMxK, concatenated with the centered coordinates
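The gather, centering and concatenation steps can be reproduced for one sample with NumPy fancy indexing. This is a sketch only: gather_neighbors and build_knn_input are hypothetical names, not the repository's operations.knn_gather_wrapper or knn_gather_by_indexing, and the toy data are made up.

```python
import numpy as np

def gather_neighbors(values, knn_idx):
    """values: C x M, knn_idx: M x K. Returns C x M x K."""
    # Indexing the second axis with an M x K integer array
    # picks the K neighbor columns of every center at once.
    return values[:, knn_idx]

def build_knn_input(coords, feats, knn_idx, center_type="avg"):
    """coords: 3 x M, feats: C x M. Returns ((3 + C) x M x K, 3 x M)."""
    neighbors = gather_neighbors(coords, knn_idx)        # 3 x M x K
    if center_type == "avg":
        center = neighbors.mean(axis=2, keepdims=True)   # mean of the K neighbors
    elif center_type == "center":
        center = coords[:, :, None]                      # the center itself
    else:
        raise ValueError("unknown center_type: %s" % center_type)
    decentered = neighbors - center                      # 3 x M x K
    x_neighbors = gather_neighbors(feats, knn_idx)       # C x M x K
    x_augmented = np.concatenate((decentered, x_neighbors), axis=0)
    return x_augmented, center[:, :, 0]

coords = np.array([[0.0, 1.0, 3.0, 7.0],
                   [0.0, 0.0, 0.0, 0.0],
                   [0.0, 0.0, 0.0, 0.0]])
feats = np.arange(8, dtype=float).reshape(2, 4)          # C = 2 features per center
knn_idx = np.array([[0, 1], [1, 0], [2, 1], [3, 2]])     # K = 2, self listed first
x_aug, centers = build_knn_input(coords, feats, knn_idx, center_type="center")
print(x_aug.shape)   # (5, 4, 2)
```

With center_type="center" the first neighbor of each center is the center itself, so its centered coordinates are all zero and only the second neighbor carries a relative offset.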
2.1.4 final_pointnet
This part of the code is in ./models/layers.py. It is a conventional PointNet structure that produces the global feature vector (global feature in the figure).
2.2 Classifier
This part of the code is located in ./models/networks.py. It is essentially three fully connected layers that output the probability of each category for the point cloud. It is trained with cross-entropy loss.
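As a rough picture of this head, the NumPy sketch below runs a global feature through three fully connected layers and evaluates a softmax cross-entropy loss. The layer widths, the ReLU activations and the random weights are assumptions for illustration, and the real classifier may include other layers (for example normalization or dropout) that are omitted here.

```python
import numpy as np

def mlp_classifier(feature, weights, biases):
    """feature: B x D. Applies the layers in order and returns B x num_classes logits."""
    h = feature
    for i, (w, b) in enumerate(zip(weights, biases)):
        h = h @ w + b
        if i < len(weights) - 1:
            h = np.maximum(h, 0.0)   # ReLU on the hidden layers only
    return h

def softmax_cross_entropy(logits, labels):
    """Mean cross-entropy between softmax(logits) and integer class labels."""
    # Subtract the row maximum for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_prob = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_prob[np.arange(len(labels)), labels].mean()

rng = np.random.default_rng(0)
dims = [32, 16, 8, 40]               # hypothetical widths, 40 output classes
weights = [0.1 * rng.standard_normal((a, b)) for a, b in zip(dims[:-1], dims[1:])]
biases = [np.zeros(b) for b in dims[1:]]

feature = rng.standard_normal((2, 32))   # stand-in for two global feature vectors
labels = np.array([3, 7])
logits = mlp_classifier(feature, weights, biases)
loss = softmax_cross_entropy(logits, labels)
print(logits.shape, float(loss))
```

With all-zero logits the loss equals log(40), the value for a uniform guess over 40 classes, which is a handy sanity check.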
3. Testing and saving
3.1 get_current_errors & visualizer.plot_current_errors
This part of the code is in ./models/classifier.py. It collects the prediction accuracy statistics and visualizes them (loss-over-time curve).
3.2 model.save_network
Save the model to the specified path.
    def save_network(self, network, network_label, epoch_label, gpu_id):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.opt.checkpoints_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if gpu_id >= 0 and torch.cuda.is_available():
            # torch.cuda.device(gpu_id)
            network.to(self.opt.device)
3.3 model.update_learning_rate
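Update the learning rate of the encoder and the classifier. Each rate is multiplied by ratio and is never allowed to drop below lr_clip.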
    def update_learning_rate(self, ratio):
        lr_clip = 0.00001

        # encoder
        lr_encoder = self.old_lr_encoder * ratio
        if lr_encoder < lr_clip:
            lr_encoder = lr_clip
        for param_group in self.optimizer_encoder.param_groups:
            param_group['lr'] = lr_encoder
        print('update encoder learning rate: %f -> %f' % (self.old_lr_encoder, lr_encoder))
        self.old_lr_encoder = lr_encoder

        # classifier
        lr_classifier = self.old_lr_classifier * ratio
        if lr_classifier < lr_clip:
            lr_classifier = lr_clip
        for param_group in self.optimizer_classifier.param_groups:
            param_group['lr'] = lr_classifier
        print('update classifier learning rate: %f -> %f' % (self.old_lr_classifier, lr_classifier))
        self.old_lr_classifier = lr_classifier
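The effect of the lower bound is easiest to see by applying the update repeatedly. The following plain-Python sketch uses a hypothetical helper decayed_lr, an assumed starting rate of 1e-3 and an assumed ratio of 0.5; only the multiply-then-clip rule comes from the code above.

```python
def decayed_lr(old_lr, ratio, lr_clip=1e-5):
    """Multiply the learning rate by ratio, but never go below lr_clip."""
    return max(old_lr * ratio, lr_clip)

lr = 1e-3          # assumed starting learning rate
history = []
for _ in range(10):
    lr = decayed_lr(lr, 0.5)   # assumed decay ratio
    history.append(lr)

print(history)
```

The rate halves six times, down to about 1.56e-5. The seventh halving would fall below the floor, so from then on it stays at 1e-5.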