The Math Behind GNNs
The core of a Graph Neural Network (GNN) lies in its message-passing mechanism, which allows nodes to iteratively update their representations by aggregating information from their neighbors.
A general formula for the update rule of a node $v$ at layer $l+1$ can be expressed as:
$$ h_v^{(l+1)} = \sigma \left( W^{(l)} \sum_{u \in \mathcal{N}(v) \cup \{v\}} \frac{h_u^{(l)}}{|\mathcal{N}(v)|} \right) $$
Let’s break down this formula term by term (a short numerical example follows the breakdown):
- $h_v^{(l+1)}$: The new feature vector (or embedding) of our target node $v$ at the next layer, $l+1$. This is what we want to compute.
- $\sigma$: A non-linear activation function, such as ReLU or Tanh. Applying this function helps the model learn more complex patterns in the data.
- $W^{(l)}$: A learnable weight matrix for layer $l$. This matrix is shared across all nodes and is updated during training to optimize the GNN’s performance.
- $\sum_{u \in \mathcal{N}(v) \cup \{v\}}$: The aggregation step. We sum the feature vectors of all nodes $u$ in the neighborhood of $v$ (denoted $\mathcal{N}(v)$), including node $v$ itself.
- $h_u^{(l)}$: The feature vector of a neighboring node $u$ from the previous layer, $l$.
- $|\mathcal{N}(v)|$: The number of neighbors of node $v$, used to normalize the aggregated sum. This normalization prevents nodes with many neighbors from ending up with disproportionately large feature vectors.
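To make the normalization concrete, here is a small worked example with made-up numbers (two neighbors and two-dimensional features, chosen purely for illustration). Suppose $\mathcal{N}(v) = \{u_1, u_2\}$ and

$$ h_v^{(l)} = \begin{bmatrix} 1 \\ 0 \end{bmatrix}, \quad h_{u_1}^{(l)} = \begin{bmatrix} 0 \\ 2 \end{bmatrix}, \quad h_{u_2}^{(l)} = \begin{bmatrix} 2 \\ 2 \end{bmatrix}, \quad |\mathcal{N}(v)| = 2. $$

The aggregation step gives

$$ \sum_{u \in \mathcal{N}(v) \cup \{v\}} \frac{h_u^{(l)}}{|\mathcal{N}(v)|} = \frac{1}{2} \left( \begin{bmatrix} 1 \\ 0 \end{bmatrix} + \begin{bmatrix} 0 \\ 2 \end{bmatrix} + \begin{bmatrix} 2 \\ 2 \end{bmatrix} \right) = \begin{bmatrix} 1.5 \\ 2 \end{bmatrix}, $$

and the new embedding is $h_v^{(l+1)} = \sigma\left( W^{(l)} \begin{bmatrix} 1.5 \\ 2 \end{bmatrix} \right)$, where $W^{(l)}$ and $\sigma$ are the learned weights and activation described above.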
In simpler terms, for each node, the GNN:
- Gathers the feature vectors from all its neighbors (and from the node itself, since $v$ is included in the sum).
- Aggregates them (in this case, by taking a normalized sum).
- Transforms the aggregated vector using a learned weight matrix.
- Applies a non-linear activation to get the node’s new feature vector for the next layer.
This process is repeated for a fixed number of layers, allowing information to propagate across the graph.
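These four steps can also be written out directly with plain tensors. The following is a minimal sketch, not an optimized implementation: it builds a dense adjacency matrix for a tiny made-up graph and applies the update rule for two layers (the graph, feature sizes, and values are all illustrative assumptions).

```python
import torch

torch.manual_seed(0)

# Toy graph: 4 nodes, undirected edges (0-1, 0-2, 1-2, 2-3).
num_nodes, in_dim, hidden_dim = 4, 3, 8
adj = torch.tensor([[0., 1., 1., 0.],
                    [1., 0., 1., 0.],
                    [1., 1., 0., 1.],
                    [0., 0., 1., 0.]])

h = torch.randn(num_nodes, in_dim)           # initial node features h^(0)
deg = adj.sum(dim=1, keepdim=True)           # |N(v)| for every node v

# Two message-passing layers with separate learnable weights W^(0), W^(1).
weights = [torch.nn.Linear(in_dim, hidden_dim, bias=False),
           torch.nn.Linear(hidden_dim, hidden_dim, bias=False)]

for W in weights:
    # 1) Gather and 2) aggregate: sum neighbor features plus the node's own,
    #    then normalize by the neighbor count |N(v)|.
    agg = (adj @ h + h) / deg
    # 3) Transform with the layer's weight matrix, 4) apply a non-linearity.
    h = torch.relu(W(agg))

print(h.shape)  # torch.Size([4, 8]): one updated embedding per node
```

In practice you would wrap this in an `nn.Module` and use sparse, edge-list-based operations, which is exactly what libraries such as PyTorch Geometric provide; the next section sketches that approach.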
PyTorch implementation
import torch
import torch_geometric
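Below is a minimal, self-contained sketch of the update rule as a custom layer built on PyTorch Geometric's `MessagePassing` base class, following the library's standard `propagate`/`message` pattern. The class name `SimpleGNNLayer`, the feature sizes, and the toy `edge_index` are illustrative assumptions, not part of the original text.

```python
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree


class SimpleGNNLayer(MessagePassing):
    """One layer of the update rule above: a sum over N(v) ∪ {v} normalized
    by |N(v)|, followed by a linear transform W^(l) and a ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # sum aggregation
        self.lin = nn.Linear(in_channels, out_channels, bias=False)  # W^(l)

    def forward(self, x, edge_index):
        num_nodes = x.size(0)
        # |N(v)|: degree of each target node before self-loops are added.
        _, col = edge_index
        deg = degree(col, num_nodes, dtype=x.dtype).clamp(min=1)
        # Include each node in its own neighborhood (the "∪ {v}" part).
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        _, col = edge_index
        norm = 1.0 / deg[col]  # per-edge scaling by the target's |N(v)|
        # Aggregate normalized neighbor features, then transform and activate.
        out = self.propagate(edge_index, x=x, norm=norm)
        return torch.relu(self.lin(out))

    def message(self, x_j, norm):
        # x_j: features of the source node of each edge, scaled by 1/|N(v)|.
        return norm.view(-1, 1) * x_j


# Example usage on a tiny made-up graph (4 nodes, 3 undirected edges):
edge_index = torch.tensor([[0, 1, 0, 2, 2, 3],
                           [1, 0, 2, 0, 3, 2]])  # both directions per edge
x = torch.randn(4, 16)                           # h^(0): 16-dim node features
layer = SimpleGNNLayer(16, 32)
print(layer(x, edge_index).shape)                # torch.Size([4, 32])
```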