GNN Node Classification
Author: Nimish Sanghi https://github.com/nsanghi
In this notebook we will use JAX, Haiku, Optax and Jraph.
JAX is a numerical computing library that combines NumPy, automatic differentiation, and first-class GPU/TPU support.
Haiku is a simple neural network library for JAX that enables users to use familiar object-oriented programming models while allowing full access to JAX's pure function transformations.
Optax is a gradient processing and optimization library for JAX.
Jraph (pronounced “giraffe”) is a lightweight library for working with graph neural networks in JAX. It provides a data structure for graphs, a set of utilities for working with graphs, and a ‘zoo’ of forkable graph neural network models.
We will use the graph convolution layer jraph.GraphConvolution on a simple graph to classify its nodes. We will also see how the accuracy can be improved by replacing the convolution layer with the graph attention layer jraph.GAT.
This notebook is based on: https://github.com/deepmind/educational/blob/master/colabs/summer_schools/intro_to_graph_nets_tutorial_with_jraph.ipynb
Setup: Install and Import libraries
Representing a Graph in jraph
In jraph, a graph is represented with a GraphsTuple object. In addition to defining the graph structure of nodes and edges, you can also store node features, edge features and global graph features in a GraphsTuple.
In the GraphsTuple, edges are represented in two aligned arrays of node indices: senders (source nodes) and receivers (destination nodes). Each index corresponds to one edge, e.g. edge i goes from senders[i] to receivers[i].
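As a concrete illustration, here is a minimal sketch (not taken from the notebook) that builds a GraphsTuple for a small directed graph; the feature shapes are arbitrary choices:

```python
import jax.numpy as jnp
import jraph

# A directed graph with 4 nodes and 3 edges: 0->1, 1->2, 2->3.
node_features = jnp.array([[0.0], [1.0], [2.0], [3.0]])  # one feature per node
senders = jnp.array([0, 1, 2])     # source node of each edge
receivers = jnp.array([1, 2, 3])   # destination node of each edge
edge_features = jnp.ones((3, 1))   # one feature per edge

graph = jraph.GraphsTuple(
    nodes=node_features,
    edges=edge_features,
    senders=senders,
    receivers=receivers,
    globals=jnp.array([[1.0]]),    # a single global feature for the graph
    n_node=jnp.array([4]),         # number of nodes (per graph in the batch)
    n_edge=jnp.array([3]))         # number of edges (per graph in the batch)
```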
Visualizing the Graph
To visualize the graph structure, we will use the networkx library because it already has functions for drawing graphs. We first convert the jraph.GraphsTuple to a networkx.DiGraph.
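A minimal conversion helper might look like the sketch below; this is an assumption for illustration, not necessarily the notebook's exact implementation:

```python
import networkx as nx

def convert_jraph_to_networkx_graph(jraph_graph: jraph.GraphsTuple) -> nx.DiGraph:
  """Copies the nodes and directed edges of a GraphsTuple into a DiGraph."""
  nx_graph = nx.DiGraph()
  for i in range(int(jraph_graph.n_node[0])):
    nx_graph.add_node(i, node_feature=jraph_graph.nodes[i])
  for s, r in zip(jraph_graph.senders, jraph_graph.receivers):
    nx_graph.add_edge(int(s), int(r))
  return nx_graph

nx.draw(convert_jraph_to_networkx_graph(graph), with_labels=True)
```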
Zachary's Karate Club Dataset
Zachary's karate club is a small dataset commonly used as an example for a social graph. It's great for demo purposes, as it's easy to visualize and quick to train a model on it.
A node represents a student or instructor in the club. An edge means that those two people have interacted outside of the class. There are two instructors in the club.
Each student is assigned to one of two instructors.
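One way to build this dataset as a GraphsTuple is sketched below. The notebook lists the undirected edges explicitly; this hedged version instead borrows networkx's built-in copy of the graph, and the one-hot node features are an assumption:

```python
def get_zacharys_karate_club() -> jraph.GraphsTuple:
  g = nx.karate_club_graph()  # 34 nodes; nodes 0 and 33 are the instructors
  senders = jnp.asarray([e[0] for e in g.edges()])
  receivers = jnp.asarray([e[1] for e in g.edges()])
  # Social interactions are symmetric, so add each edge in both directions.
  senders, receivers = (jnp.concatenate([senders, receivers]),
                        jnp.concatenate([receivers, senders]))
  return jraph.GraphsTuple(
      nodes=jnp.eye(34),       # one-hot feature per node (an assumption)
      edges=None,
      senders=senders,
      receivers=receivers,
      globals=None,
      n_node=jnp.array([34]),
      n_edge=jnp.array([senders.shape[0]]))
```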
Optimizing the GCN on the Karate Club Node Classification Task
The task is to predict the assignment of students to instructors, given the social graph and only knowing the assignment of two nodes (the two instructors) a priori.
In other words, out of the 34 nodes, only two nodes are labeled, and we are trying to optimize the assignment of the other 32 nodes, by maximizing the log-likelihood of the two known node assignments.
We will compute the accuracy of our node assignments by comparing to the ground-truth assignments. Note that the ground-truth for the 32 student nodes is not used in the loss function itself.
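A hedged sketch of such a loss is shown below; masking by node index (0 and 33 are the instructors in the standard node ordering) and the function names are assumptions:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, graph, labels, net):
  """Negative log-likelihood on the two labeled (instructor) nodes only."""
  logits = net.apply(params, graph).nodes           # [34, 2] class logits
  log_probs = jax.nn.log_softmax(logits)
  nll = -jnp.take_along_axis(log_probs, labels[:, None], axis=1)[:, 0]
  mask = jnp.zeros(labels.shape[0]).at[jnp.array([0, 33])].set(1.0)
  return jnp.sum(nll * mask) / jnp.sum(mask)        # average over labeled nodes
```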
Visualize the karate club graph with circular node layout:
Node Classification with GCN
Define the GCN with the jraph.GraphConvolution layer. We will use two convolution layers.
jraph.GraphConvolution requires the following parameters:
- update_node_fn - function used to update the nodes. We will use a single-layer MLP with ReLU.
- aggregate_nodes_fn - function used to aggregate the sender nodes. The default is jax.ops.segment_sum, i.e. to sum the node features of the neighbors. We will use the default value.
- add_self_edges - whether to add self edges to nodes in the graph. Defaults to False. We will set this value to True.
- symmetric_normalization - whether to use symmetric normalization. Defaults to True.
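Under these choices, a two-layer GCN in Haiku might look like the sketch below; the hidden width (8) and the 2-class output are illustrative assumptions:

```python
import haiku as hk

def gcn_fn(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  gcn1 = jraph.GraphConvolution(
      update_node_fn=lambda n: jax.nn.relu(hk.Linear(8)(n)),  # single-layer MLP + ReLU
      add_self_edges=True)
  gcn2 = jraph.GraphConvolution(
      update_node_fn=hk.Linear(2),  # raw logits for the two classes
      add_self_edges=True)
  return gcn2(gcn1(graph))

network = hk.without_apply_rng(hk.transform(gcn_fn))
```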
Training and Evaluation code:
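A minimal training loop, assuming the loss_fn sketched earlier and an Adam optimizer from optax (the notebook's actual hyperparameters may differ):

```python
import optax

def train(net, graph, labels, num_steps=100, learning_rate=1e-2):
  params = net.init(jax.random.PRNGKey(42), graph)
  opt = optax.adam(learning_rate)
  opt_state = opt.init(params)

  @jax.jit
  def step(params, opt_state):
    loss, grads = jax.value_and_grad(loss_fn)(params, graph, labels, net)
    updates, opt_state = opt.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state, loss

  for _ in range(num_steps):
    params, opt_state, loss = step(params, opt_state)
  return params
```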
Let's train the GCN and check the accuracy. We expect this model to reach an accuracy of about 0.91.
Visualize ground truth and predicted node assignments:
Node Classification with Graph Attention (GAT) layer
While the GCN can learn meaningful representations, it also has some shortcomings.
In the GCN layer, the messages from all of a node's neighbours and from the node itself are equally weighted. This may lead to a loss of node-specific information. E.g., consider the case where a set of nodes shares the same set of neighbours but starts out with different node features: because of the averaging, their resulting output features would be the same. Adding self-edges mitigates this issue by a small amount, but the problem is magnified with an increasing number of GCN layers and of edges connecting to a node.
The graph attention (GAT) mechanism, as proposed by Velickovic et al. (2017), allows the network to learn how to weigh / assign importance to the node features from the neighbourhood when computing the new node features. This is very similar to the idea of using attention in Transformers, which were introduced in Vaswani et al. (2017).
(One could even argue that Transformers are graph attention networks operating on the special case of fully-connected graphs.)
In the figure below, $\vec{h}_i$ are the node features and $\alpha_{ij}$ are the learned attention weights.
Figure Credit: Velickovic et al. (2017). (Detail: this image shows multi-headed attention with 3 heads, each color corresponding to a different head. At the end, an aggregation function is applied over all the heads.)
To obtain the output node features of the GAT layer, we compute:

$$\vec{h}'_i = \sum_{j \in \mathcal{N}(i)} \alpha_{ij} \mathbf{W} \vec{h}_j$$

Here, $\mathbf{W}$ is a weight matrix which performs a linear transformation on the input.
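For completeness, the attention weights in Velickovic et al. (2017) are a softmax over learned compatibility scores of neighboring nodes (this formula is quoted from the paper rather than from the notebook):

$$\alpha_{ij} = \frac{\exp\left(\mathrm{LeakyReLU}\left(\vec{a}^\top \left[\mathbf{W}\vec{h}_i \,\|\, \mathbf{W}\vec{h}_j\right]\right)\right)}{\sum_{k \in \mathcal{N}(i)} \exp\left(\mathrm{LeakyReLU}\left(\vec{a}^\top \left[\mathbf{W}\vec{h}_i \,\|\, \mathbf{W}\vec{h}_k\right]\right)\right)}$$

where $\vec{a}$ is a learned weight vector and $\|$ denotes concatenation.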
We will use jraph.GAT(attention_query_fn, attention_logit_fn, node_update_fn=None) to build the graph attention network. It takes the following parameters:
- attention_query_fn - function that generates attention queries from sender node features.
- attention_logit_fn - function that converts attention queries into logits for softmax attention.
- node_update_fn - function that updates the aggregated messages. If None, it will apply a leaky ReLU and concatenate the heads (if using multi-head attention).
jraph.GAT assumes that self-edges are already present in the graph it is given, so we need to add them ourselves, as in the sketch below.
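Putting this together, a hedged two-layer GAT sketch follows; the hidden width, the helper names, and the identity node_update_fn on the final layer are assumptions rather than the notebook's exact code:

```python
def add_self_edges_fn(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Appends an i->i edge for every node (assumes graph.edges is None)."""
  total_num_nodes = int(graph.n_node[0])
  idx = jnp.arange(total_num_nodes)
  return graph._replace(
      senders=jnp.concatenate([graph.senders, idx]),
      receivers=jnp.concatenate([graph.receivers, idx]),
      n_edge=graph.n_edge + total_num_nodes)

def gat_fn(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  def attention_logit_fn(sender_attr, receiver_attr, edges):
    del edges  # the score depends only on the two endpoint features
    return hk.Linear(1)(jnp.concatenate([sender_attr, receiver_attr], axis=1))

  gat1 = jraph.GAT(
      attention_query_fn=lambda n: hk.Linear(8)(n),
      attention_logit_fn=attention_logit_fn)  # default node_update_fn: leaky ReLU
  gat2 = jraph.GAT(
      attention_query_fn=lambda n: hk.Linear(2)(n),
      attention_logit_fn=attention_logit_fn,
      node_update_fn=lambda n: n)             # identity: keep the final logits
  return gat2(gat1(graph))

network = hk.without_apply_rng(hk.transform(gat_fn))
```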
Let's train the model. We expect the model to reach an accuracy of about 0.97.
Visualize ground truth and predicted node assignments: