Path: blob/master/notebooks/book1/23/gnn_graph_classification_jraph.ipynb
Graph Classification on MUTAG (Molecules)
Author: Nimish Sanghi https://github.com/nsanghi
In this notebook we will use JAX, Haiku, Optax and Jraph.
You can read through the Node Classification notebook for details on each of these libraries and how they work together to train GNN models.
Here we go one step further and build a complete graph classification model for MUTAG.
This notebook is based on: https://github.com/deepmind/educational/blob/master/colabs/summer_schools/intro_to_graph_nets_tutorial_with_jraph.ipynb
The main difference from our previous setup is that instead of observing individual node latents, we are now attempting to summarize them into one embedding vector, representative of the entire graph, which we then use to predict the class of this graph.
We will do this on one of the most common tasks of this type -- molecular property prediction, where molecules are represented as graphs. Nodes correspond to atoms, and edges represent the bonds between them.
We will use the MUTAG dataset for this example, a common dataset from the TUDatasets collection.
The authors of the DeepMind notebook have converted this dataset to a jraph-compatible format, and we will download it in the cells below.
Citation for TUDatasets: Morris, Christopher, et al. TUDataset: A collection of benchmark datasets for learning with graphs. arXiv preprint arXiv:2007.08663, 2020.
Setup: Install and Import libraries
Representing a Graph in jraph
In jraph, a graph is represented with a GraphsTuple object. In addition to defining the graph structure of nodes and edges, you can also store node features, edge features and global graph features in a GraphsTuple.
In the GraphsTuple, edges are represented by two aligned arrays of node indices: senders (source nodes) and receivers (destination nodes). Each index corresponds to one edge, e.g. edge i goes from senders[i] to receivers[i].
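To make this concrete, here is a minimal sketch (not part of the original notebook) that builds a GraphsTuple for a toy three-node graph with two directed edges; the feature sizes are arbitrary.

```python
# A toy GraphsTuple: 3 nodes, 2 directed edges, one global feature vector.
import jax.numpy as jnp
import jraph

toy_graph = jraph.GraphsTuple(
    n_node=jnp.array([3]),          # number of nodes per graph
    n_edge=jnp.array([2]),          # number of edges per graph
    nodes=jnp.ones((3, 4)),         # one 4-dim feature vector per node
    edges=jnp.ones((2, 5)),         # one 5-dim feature vector per edge
    senders=jnp.array([0, 1]),      # edge i starts at senders[i]
    receivers=jnp.array([1, 2]),    # edge i ends at receivers[i]
    globals=jnp.ones((1, 6)),       # one global feature vector per graph
)
print(toy_graph.senders, toy_graph.receivers)
```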
Let us now download the jraph version of MUTAG
--2022-06-30 11:25:01-- https://storage.googleapis.com/dm-educational/assets/graph-nets/jraph_datasets/mutag.pickle
Resolving storage.googleapis.com (storage.googleapis.com)... 142.250.99.128, 173.194.202.128, 173.194.203.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.250.99.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 352981 (345K) [application/octet-stream]
Saving to: ‘/tmp/mutag.pickle’
mutag.pickle 100%[===================>] 344.71K --.-KB/s in 0.002s
2022-06-30 11:25:01 (145 MB/s) - ‘/tmp/mutag.pickle’ saved [352981/352981]
Visualizing the Graph
To visualize the graph structure, we will use the networkx library, because it already has functions for drawing graphs. We first convert the jraph.GraphsTuple to a networkx.DiGraph.
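The sketch below shows one way to do this conversion and drawing; the function names here are ours, and the notebook defines its own (similar) helpers.

```python
# Sketch: convert a jraph.GraphsTuple to a networkx.DiGraph and draw it.
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import jraph


def jraph_to_networkx(graph: jraph.GraphsTuple) -> nx.DiGraph:
    nx_graph = nx.DiGraph()
    # Copy nodes (with their feature vectors, if present).
    for n in range(int(graph.n_node[0])):
        feat = None if graph.nodes is None else np.asarray(graph.nodes[n])
        nx_graph.add_node(n, node_feature=feat)
    # Copy directed edges from the aligned senders/receivers arrays.
    for e in range(int(graph.n_edge[0])):
        feat = None if graph.edges is None else np.asarray(graph.edges[e])
        nx_graph.add_edge(int(graph.senders[e]), int(graph.receivers[e]),
                          edge_feature=feat)
    return nx_graph


def draw_jraph_graph(graph: jraph.GraphsTuple) -> None:
    nx_graph = jraph_to_networkx(graph)
    pos = nx.spring_layout(nx_graph)
    nx.draw(nx_graph, pos=pos, with_labels=True, node_size=300, font_color="white")
    plt.show()
```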
The dataset is saved as a list of examples; each example is a dictionary containing an input_graph and its corresponding target.
Let us look at the first graph:
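As a sketch, the pickle downloaded to /tmp/mutag.pickle can be loaded and inspected like this; the variable names are ours, and the dictionary keys follow the description above.

```python
import pickle

# Load the jraph version of MUTAG downloaded above.
with open("/tmp/mutag.pickle", "rb") as f:
    mutag_ds = pickle.load(f)

print(len(mutag_ds))               # number of graphs in the dataset
first = mutag_ds[0]
print(first["input_graph"])        # a jraph.GraphsTuple for one molecule
print(first["target"])             # its binary class label
```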
We see that there are 188 graphs, to be classified into one of 2 classes, representing "their mutagenic effect on a specific gram negative bacterium". Node features represent the one-hot encoding of the atom type (0=C, 1=N, 2=O, 3=F, 4=I, 5=Cl, 6=Br). Edge features (edge_attr) represent the bond type, which we will ignore here.
Let's split the dataset to use the first 150 graphs as the training set (and the rest as the test set).
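A minimal sketch of that split (the variable names are ours):

```python
# First 150 examples for training, the remaining 38 for testing.
train_mutag_ds = mutag_ds[:150]
test_mutag_ds = mutag_ds[150:]
```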
Padding Graphs to Speed Up Training
Since JAX recompiles the program for each new graph shape, training would take a long time due to recompilation for every distinct graph size. To address this, we pad the number of nodes and edges in the graphs to the nearest power of two, as sketched below. Since JAX maintains a cache of compiled programs, the compilation cost is then amortized.
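Here is a sketch of such a padding helper, assuming jraph.pad_with_graphs as the padding primitive; the helper names are ours.

```python
import numpy as np
import jraph


def _nearest_bigger_power_of_two(x: int) -> int:
    """Returns the smallest power of two that is >= x."""
    y = 2
    while y < x:
        y *= 2
    return y


def pad_graph_to_nearest_power_of_two(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    # Pad the node and edge counts up to the next power of two. We add one
    # extra node and one extra (dummy) graph, since jraph's padding needs at
    # least one padding graph to attach the padding nodes/edges to.
    pad_nodes_to = _nearest_bigger_power_of_two(int(np.sum(graph.n_node))) + 1
    pad_edges_to = _nearest_bigger_power_of_two(int(np.sum(graph.n_edge)))
    pad_graphs_to = graph.n_node.shape[0] + 1
    return jraph.pad_with_graphs(graph, pad_nodes_to, pad_edges_to, pad_graphs_to)
```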
Graph Network Model Definition
We will use jraph.GraphNetwork() to build our graph model. The GraphNetwork architecture is defined in Battaglia et al. (2018). This function takes as parameters the update functions for edges, nodes, and the full graph (globals).
We first define these update functions, using MLP blocks for all three.
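The sketch below shows one way to wire these pieces together with Haiku MLPs and jraph.GraphNetwork; the layer sizes are illustrative, and the model assumes the input graph carries (possibly dummy) global features.

```python
import haiku as hk
import jax.numpy as jnp
import jraph


@jraph.concatenated_args
def edge_update_fn(features) -> jnp.ndarray:
    """Edge update: MLP over concatenated edge, sender, receiver, global features."""
    return hk.nets.MLP([128, 128])(features)


@jraph.concatenated_args
def node_update_fn(features) -> jnp.ndarray:
    """Node update: MLP over concatenated node, aggregated-edge, global features."""
    return hk.nets.MLP([128, 128])(features)


@jraph.concatenated_args
def global_update_fn(features) -> jnp.ndarray:
    """Global update: MLP producing 2 logits, one per MUTAG class."""
    return hk.nets.MLP([128, 2])(features)


def net_fn(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    # Embed the raw node/edge/global features, then apply one GraphNetwork block.
    embedder = jraph.GraphMapFeatures(
        hk.Linear(128), hk.Linear(128), hk.Linear(128))
    net = jraph.GraphNetwork(
        update_edge_fn=edge_update_fn,
        update_node_fn=node_update_fn,
        update_global_fn=global_update_fn)
    return net(embedder(graph))


# The model is used via Haiku's functional API, e.g.:
# net = hk.without_apply_rng(hk.transform(net_fn))
```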
Loss and Accuracy Function
Define the classification cross-entropy loss and accuracy function.
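A sketch of these functions, assuming the model writes per-graph logits into globals, that net is the Haiku-transformed model from above, and that padding graphs are masked out with jraph.get_graph_padding_mask.

```python
import jax
import jax.numpy as jnp
import jraph


def compute_loss(params, graph, label, net):
    """Cross-entropy loss and accuracy over the real (non-padding) graphs."""
    pred_graph = net.apply(params, graph)
    log_probs = jax.nn.log_softmax(pred_graph.globals)   # [num_graphs, 2]
    targets = jax.nn.one_hot(label, 2)                   # one-hot class labels
    mask = jraph.get_graph_padding_mask(pred_graph)      # True for real graphs
    # Mean negative log-likelihood over real graphs only.
    loss = -jnp.sum(log_probs * targets * mask[:, None]) / jnp.sum(mask)
    accuracy = jnp.sum(
        (jnp.argmax(pred_graph.globals, axis=1) == label) * mask) / jnp.sum(mask)
    return loss, accuracy
```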