1 Introduction to message passing
Generalizing the convolution operator to irregular domains is usually expressed as a neighborhood aggregation, or message passing, scheme.
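Given the node features x_i^{(k-1)} of layer (k-1) and, optionally, the edge features e_{j,i} from node j to node i, a message passing GNN can be described as:
$$
\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \Phi^{(k)} \left( \mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)}, \mathbf{e}_{j,i} \right) \right)
$$
where □ denotes a differentiable, permutation-invariant function (such as sum, mean, or max), and γ and Φ denote differentiable functions (such as MLPs).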
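The scheme can be sketched without any library. The following is a minimal illustration in plain Python, not PyG code: the helper `message_passing_step`, the scalar node features, and the choice to give nodes without incoming edges an aggregate of 0.0 are all assumptions made for this example, and edge features are left out for brevity.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Edge = Tuple[int, int]  # (j, i): a message flows from source j to target i


def message_passing_step(
    x: Sequence[float],
    edges: Sequence[Edge],
    phi: Callable[[float, float], float],
    aggregate: Callable[[List[float]], float],
    gamma: Callable[[float, float], float],
) -> List[float]:
    """One layer: x_i' = gamma(x_i, aggregate({phi(x_i, x_j) : (j, i) in edges}))."""
    inbox: Dict[int, List[float]] = {i: [] for i in range(len(x))}
    for j, i in edges:
        inbox[i].append(phi(x[i], x[j]))  # message from neighbor j to node i
    # Nodes without incoming edges get an aggregate of 0.0 (assumption of this sketch).
    return [
        gamma(x[i], aggregate(inbox[i]) if inbox[i] else 0.0)
        for i in range(len(x))
    ]


# Toy graph with edges 0 -> 1, 2 -> 1 and 1 -> 0, and scalar features.
x = [1.0, 2.0, 4.0]
edges = [(0, 1), (2, 1), (1, 0)]
out = message_passing_step(
    x,
    edges,
    phi=lambda x_i, x_j: x_j,          # message: copy the neighbor feature
    aggregate=sum,                     # permutation-invariant aggregation
    gamma=lambda x_i, agg: x_i + agg,  # update: add the aggregate to the own feature
)
print(out)  # [3.0, 7.0, 4.0]
```

Because the aggregation is permutation invariant, listing the edges in a different order gives the same result.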
2 MessagePassing class
PyG provides the MessagePassing base class, which helps in creating this kind of message passing graph neural network by automatically taking care of message propagation.
Users only need to define the functions Φ (message()) and γ (update()), as well as the aggregation scheme aggr, i.e. aggr="add", aggr="mean" or aggr="max".
2.1 MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)
aggr defines the aggregation scheme to use (here "add").
flow defines the direction of message passing, either "source_to_target" (default) or "target_to_source".
node_dim indicates the axis along which to propagate.
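To make the three aggregation schemes concrete, here is a small plain-Python sketch that applies each of them to the same list of incoming messages. The dictionary `aggregators` is a stand-in invented for this illustration and is not part of PyG.

```python
from statistics import mean

# Messages arriving at one node from its three neighbors (made-up values).
messages = [1.0, 4.0, 2.5]

# Stand-ins for aggr="add", aggr="mean" and aggr="max".
aggregators = {"add": sum, "mean": mean, "max": max}

results = {name: fn(messages) for name, fn in aggregators.items()}
print(results)  # {'add': 7.5, 'mean': 2.5, 'max': 4.0}
```

All three are permutation invariant: reordering `messages` does not change any of the results.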
2.2 MessagePassing.propagate(edge_index, size=None, **kwargs)
The initial call to start propagating messages.
It takes the edge indices and all additional data needed to construct messages and to update node embeddings.
propagate() is not limited to exchanging messages in square adjacency matrices of shape [N, N]; it can also handle non-square matrices, for example bipartite graphs of shape [N, M], by passing size=(N, M) as an additional argument.
If size is set to None, the adjacency matrix is assumed to be square.
A bipartite graph of shape [N, M] has two independent sets of nodes, so the features must also be passed as a tuple x=(x_N, x_M).
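The bipartite case can be pictured with a short plain-Python sketch. The function `propagate_bipartite` below is a hypothetical helper written for this illustration, not the PyG method: it sums the features of N source nodes into M target nodes, which is why the output has M entries.

```python
from typing import List, Sequence, Tuple


def propagate_bipartite(
    edges: Sequence[Tuple[int, int]],
    x_src: Sequence[float],
    x_dst: Sequence[float],
) -> List[float]:
    """Sum source features into target nodes; one output entry per target node."""
    out = [0.0] * len(x_dst)  # length M, not N; x_dst is only used for its size here
    for j, i in edges:        # j indexes x_src (N nodes), i indexes x_dst (M nodes)
        out[i] += x_src[j]
    return out


x_n = [1.0, 2.0, 3.0]  # N = 3 source nodes
x_m = [10.0, 20.0]     # M = 2 target nodes
edges = [(0, 0), (1, 0), (2, 1)]
result = propagate_bipartite(edges, x_n, x_m)
print(result)  # [3.0, 3.0]
```

The index spaces of the two node sets are independent: source index 0 and target index 0 refer to different nodes.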
2.3 MessagePassing.message(...)
Analogous to Φ: constructs the messages to node i. If flow="source_to_target", messages are built for every edge (j,i) ∈ E; if flow="target_to_source", for every edge (i,j) ∈ E.
It can take any argument that was initially passed to propagate().
In addition, tensors passed to propagate() can be mapped to the respective nodes by appending _i or _j to the variable name, for example x_i (the central node) and x_j (the neighboring node).
Note that we usually call i the central node that aggregates information and j the neighboring node, since this is the most common notation.
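How the _i and _j suffixes relate to the flow direction can be sketched in plain Python. The helper `gather_i_j` is hypothetical and only mimics the lookup on lists; it assumes that the first row of the edge index holds the source nodes and the second row the target nodes.

```python
from typing import List, Sequence, Tuple


def gather_i_j(
    x: Sequence[float],
    edge_index: Tuple[Sequence[int], Sequence[int]],
    flow: str = "source_to_target",
) -> Tuple[List[float], List[float]]:
    """Return the per-edge lists (x_i, x_j) for the given flow direction."""
    row, col = edge_index
    if flow == "source_to_target":
        j_idx, i_idx = row, col  # edge (j, i): row holds neighbors j, col holds central nodes i
    elif flow == "target_to_source":
        i_idx, j_idx = row, col  # edge (i, j): row holds central nodes i, col holds neighbors j
    else:
        raise ValueError(f"unknown flow: {flow!r}")
    x_i = [x[i] for i in i_idx]
    x_j = [x[j] for j in j_idx]
    return x_i, x_j


x = [10.0, 20.0, 30.0]
edge_index = ([0, 2], [1, 1])  # two edges: 0 -> 1 and 2 -> 1
x_i, x_j = gather_i_j(x, edge_index)
print(x_i, x_j)  # [20.0, 20.0] [10.0, 30.0]
```

With flow="target_to_source" the roles swap, so the same call returns x_i = [10.0, 30.0] and x_j = [20.0, 20.0].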
2.4 MessagePassing.update(aggr_out, ...)
Analogous to γ: updates the node embedding of every node i ∈ V.
It takes the output of the aggregation as its first argument, followed by any argument that was initially passed to propagate().
3 example: GCN
3.1 GCN review
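The GCN layer can be expressed as:
$$
\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right)
$$
The features of the neighboring nodes from layer (k-1) are first transformed by the weight matrix Θ, then normalized by the degrees of the central node and the neighboring node, and finally summed up.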
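As a worked example, consider a path graph with three nodes, 1 - 2 - 3 (a toy graph chosen only for illustration). Assuming the degrees are counted after self-loops have been added, deg(1) = 2, deg(2) = 3 and deg(3) = 2, and the update of the middle node is:

```latex
% Path graph 1 - 2 - 3; with self-loops: deg(1) = 2, deg(2) = 3, deg(3) = 2.
\begin{align*}
\mathbf{x}_2^{(k)}
  &= \frac{1}{\sqrt{3} \cdot \sqrt{2}} \, \mathbf{\Theta} \mathbf{x}_1^{(k-1)}
   + \frac{1}{\sqrt{3} \cdot \sqrt{3}} \, \mathbf{\Theta} \mathbf{x}_2^{(k-1)}
   + \frac{1}{\sqrt{3} \cdot \sqrt{2}} \, \mathbf{\Theta} \mathbf{x}_3^{(k-1)} \\
  &= \frac{1}{\sqrt{6}} \, \mathbf{\Theta} \left( \mathbf{x}_1^{(k-1)} + \mathbf{x}_3^{(k-1)} \right)
   + \frac{1}{3} \, \mathbf{\Theta} \mathbf{x}_2^{(k-1)}
\end{align*}
```

The node's own feature enters with weight 1/3 through the self-loop, while each neighbor contributes with weight 1/√6.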
3.2 implementation process of message passing
This equation can be divided into the following steps:
- Add self-loops to the adjacency matrix (the sum Σ in the formula above runs over i itself in addition to i's neighbors)
- Linearly transform the node feature matrix
- Compute the normalization coefficients
- Normalize the previous-layer features of the neighbors and of the node itself (Φ, the message operation)
- Sum up the normalized features of the neighbors and of the node itself ("add", the aggregation operation)
Steps 1 to 3 are computed before message passing starts; steps 4 and 5 are handled by the MessagePassing base class.
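The five steps can be traced by hand on a tiny graph. The following plain-Python sketch is an illustration under simplifying assumptions, not the PyG implementation: node features are scalars, a scalar `theta` stands in for the weight matrix, and the function name `gcn_layer_scalar` is made up for this example.

```python
import math
from typing import List, Sequence, Tuple


def gcn_layer_scalar(
    x: Sequence[float],
    edges: Sequence[Tuple[int, int]],
    theta: float,
) -> List[float]:
    """Apply the five steps to scalar node features with a scalar weight theta."""
    n = len(x)
    # Step 1: add a self-loop (i, i) for every node.
    edges = list(edges) + [(i, i) for i in range(n)]
    # Step 2: linear transform (a scalar multiply stands in for the weight matrix).
    h = [theta * v for v in x]
    # Step 3: in-degree of every node, then one coefficient per edge (j, i).
    deg = [0] * n
    for _, i in edges:
        deg[i] += 1
    norm = [1.0 / math.sqrt(deg[j] * deg[i]) for j, i in edges]
    # Step 4: message = normalized feature of the source node j.
    messages = [c * h[j] for c, (j, _) in zip(norm, edges)]
    # Step 5: "add" aggregation into the target node i.
    out = [0.0] * n
    for m, (_, i) in zip(messages, edges):
        out[i] += m
    return out


# Undirected path 0 - 1 - 2, stored as directed edges in both directions.
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
out = gcn_layer_scalar([1.0, 1.0, 1.0], edges, theta=1.0)
print(out)
```

Thanks to the self-loops every node has degree at least 1, so the sketch needs no guard against division by zero. For this graph the end nodes receive 1/√6 + 1/2 and the middle node receives 2/√6 + 1/3.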
3.3 code analysis
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree


class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]: N nodes, each with in_channels features.
        # edge_index has shape [2, E]: E edges; row 0 holds the source nodes,
        # row 1 the target nodes.

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index  # source and target node of each edge
        # In-degree of each node (for an undirected graph, in- and out-degree coincide).
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]  # 1/sqrt(d_j) * 1/sqrt(d_i)

        # Step 4-5: Start propagating messages.
        # propagate() internally calls message(), aggregate() and update();
        # the node embeddings x and the coefficients norm are passed along
        # as additional arguments.
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]: the source node features of each edge,
        # i.e. the neighbor features of each node.

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j  # 1/sqrt(d_i) * 1/sqrt(d_j) * x_j
After that, the layer can be called like this:
conv = GCNConv(16, 32)
x = conv(x, edge_index)