The Math Behind GNNs

The core of a Graph Neural Network (GNN) lies in its message-passing mechanism, which allows nodes to iteratively update their representations by aggregating information from their neighbors.

A general formula for the update rule of a node $v$ at layer $l+1$ can be expressed as:

$$ h_v^{(l+1)} = \sigma \left( W^{(l)} \sum_{u \in \mathcal{N}(v) \cup \{v\}} \frac{h_u^{(l)}}{|\mathcal{N}(v)|} \right) $$

Let’s break down this formula:

  • $h_v^{(l+1)}$: This is the new feature vector (or embedding) of our target node $v$ at the next layer, $l+1$. This is what we want to compute.

  • $\sigma$: This represents a non-linear activation function, such as ReLU or Tanh. Applying this function helps the model learn more complex patterns in the data.

  • $W^{(l)}$: This is a learnable weight matrix for layer $l$. This matrix is shared across all nodes and is updated during the training process to optimize the GNN’s performance.

  • $\sum_{u \in \mathcal{N}(v) \cup \{v\}}$: This is the aggregation step. We are summing up the feature vectors of all nodes $u$ that are in the neighborhood of $v$ (denoted by $\mathcal{N}(v)$), including the node $v$ itself.

  • $h_u^{(l)}$: This is the feature vector of a neighboring node $u$ from the previous layer, $l$.

  • $| \mathcal{N}(v)|$: This is the number of neighbors of node $v$, which is used to normalize the aggregated sum. This normalization prevents nodes with a large number of neighbors from having disproportionately large feature vectors.

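As a small worked example with made-up numbers: suppose node $v$ has two neighbors $u_1$ and $u_2$, with one-dimensional features $h_{u_1}^{(l)} = 2$, $h_{u_2}^{(l)} = 4$, and $h_v^{(l)} = 6$. Then $|\mathcal{N}(v)| = 2$, and the update is

$$ h_v^{(l+1)} = \sigma \left( W^{(l)} \cdot \frac{2 + 4 + 6}{2} \right) = \sigma \left( 6 \, W^{(l)} \right) $$

That is, the features of the neighbors and of $v$ itself are summed, normalized by the number of neighbors, transformed by the weight matrix, and passed through the non-linearity.
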
In simpler terms, for each node, the GNN:

  1. Gathers the feature vectors of all its neighbors (and its own feature vector, since the sum includes $v$ itself).
  2. Aggregates them (in this case, by taking a normalized sum).
  3. Transforms the aggregated vector using a learned weight matrix.
  4. Applies a non-linear activation to get the node’s new feature vector for the next layer.

This process is repeated for a fixed number of layers; after $k$ layers, a node’s representation reflects information from nodes up to $k$ hops away, allowing information to propagate across the graph.
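
To make these four steps concrete, here is a minimal sketch of a single layer written from scratch in PyTorch. It follows the update rule above directly; the class name SimpleGNNLayer, the dense adjacency-matrix representation, and the choice of ReLU are illustrative assumptions, not part of any particular library.

import torch
import torch.nn as nn

class SimpleGNNLayer(nn.Module):
    # One message-passing layer implementing the update rule above (illustrative sketch).
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)   # learnable weight matrix W^(l)

    def forward(self, h, adj):
        # h:   (N, in_dim) matrix of node features h^(l)
        # adj: (N, N) binary adjacency matrix without self-loops
        deg = adj.sum(dim=1, keepdim=True).clamp(min=1)    # |N(v)| for each node
        agg = (adj + torch.eye(adj.size(0))) @ h / deg     # steps 1-2: sum over neighbors plus v itself, then normalize
        return torch.relu(self.W(agg))                     # steps 3-4: transform with W^(l), apply the non-linearity

Stacking such layers repeats the process, so two layers give every node a two-hop receptive field:

h = torch.randn(5, 8)                          # 5 nodes, 8 input features (random toy data)
adj = torch.tensor([[0., 1., 0., 0., 1.],
                    [1., 0., 1., 0., 0.],
                    [0., 1., 0., 1., 0.],
                    [0., 0., 1., 0., 1.],
                    [1., 0., 0., 1., 0.]])     # a 5-node cycle graph
layer1, layer2 = SimpleGNNLayer(8, 16), SimpleGNNLayer(16, 4)
out = layer2(layer1(h, adj), adj)              # shape (5, 4)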

PyTorch implementation

import torch
import torch_geometric  # PyTorch Geometric: graph data structures and ready-made GNN layers
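
The imports above only load the libraries; below is a minimal sketch of how the rest of such an implementation might look using PyTorch Geometric’s built-in GCNConv layer. Note that GCNConv uses a symmetric degree-based normalization rather than the simple $1/|\mathcal{N}(v)|$ factor in the formula above; the model name, layer sizes, and toy data here are illustrative assumptions.

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    # Two stacked graph-convolution layers (illustrative sketch).
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        # x: (num_nodes, in_channels) node feature matrix
        # edge_index: (2, num_edges) tensor of source/target node indices
        x = F.relu(self.conv1(x, edge_index))   # layer 1: aggregate, transform, non-linearity
        return self.conv2(x, edge_index)        # layer 2: output embeddings

x = torch.randn(4, 8)                           # 4 nodes with 8 features each (toy data)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]]) # undirected path 0-1-2-3, both directions listed
model = GCN(in_channels=8, hidden_channels=16, out_channels=2)
out = model(x, edge_index)                      # shape (4, 2): one 2-dimensional embedding per node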