Graph Neural Networks Node Classification Quick Introduction
In this notebook, we'll be:
Giving a quick introduction to Graph Neural Networks (GNN).
Implementing one particular algorithm, GraphSAGE, using DGL (Deep Graph Library). We'll be introducing how to work with the DGL library on the graph node classification task.
Using PyTorch Lightning for organizing our building blocks and training our model. This isn't a PyTorch Lightning tutorial; readers are expected to understand its concepts such as LightningDataModule, LightningModule, as well as Trainer.
A graph is a data structure containing a set of nodes (a.k.a. vertices) $V$ and a set of edges $E$, where each edge connects a pair of vertices $u$ and $v$. Each node has associated node features and labels.
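As a tiny sketch of this data structure using DGL's graph construction API (the node count, edges, and feature dimension here are illustrative):

```python
# a tiny sketch: a 3-node directed graph with edges (0->1), (1->2), (2->0)
import dgl
import torch

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
g.ndata["feat"] = torch.randn(3, 4)  # a 4-dimensional feature vector per node
print(g.num_nodes(), g.num_edges())  # 3 3
```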
A single graph neural network (GNN) layer has three main steps that are performed on every node in the graph:
Message Passing
Aggregation
Update
Message Passing:
A GNN learns a node $u$'s representation by examining nodes in its neighborhood $N(u)$, where $N(u)$ is defined as the set of nodes connected to the source node $u$ by an edge; more formally, $N(u) = \{v \in V \mid (u, v) \in E\}$. When examining a node, we can use any arbitrary function $F$, e.g. a neural network such as an MLP; here, let's assume it will be an affine transformation:

$$m_v = F(h_v) = W \cdot h_v + b, \quad \forall v \in N(u)$$

Here, $\cdot$ represents matrix multiplication, $W$ a weight matrix, and $b$ a bias term.
Aggregation:
Now that we have the messages from the neighboring nodes, we have to aggregate them somehow. Popular aggregation functions include sum/mean/max/min. The aggregation function, $G$, can be denoted as:

$$\bar{m}_u = G\big(\{m_v, \forall v \in N(u)\}\big)$$
Update:
The GNN layer now has to update our source node $u$'s features and combine them with the incoming aggregated messages, for example, using addition:

$$h_u' = \sigma\big(U\big(K(h_u) + \bar{m}_u\big)\big)$$

Here, $K$ denotes a function that's applied to the source node $u$'s features, $U$ denotes another transformation that projects the blended features into another dimension, and $\sigma$ denotes an activation function such as ReLU. Notation-wise, the initial node features are called $h_u^{(0)}$; after a forward pass through a GNN layer, we denote the node features as $h_u^{(1)}$. If we were to have multiple GNN layers, then we denote the node features as $h_u^{(k)}$, where $k$ is the current GNN layer index.
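To make the three steps concrete, here's a toy sketch of a single GNN layer's computation for one source node, assuming the affine message function, mean aggregation, and the additive update described above; the dimensions and the linear layers F_msg, K, U are illustrative choices, not from the original.

```python
# a toy sketch of one GNN layer update for a single source node u
import torch
import torch.nn as nn

in_dim, out_dim = 8, 4
F_msg = nn.Linear(in_dim, out_dim)  # message function F: W*h_v + b
K = nn.Linear(in_dim, out_dim)      # transform applied to the source node's features
U = nn.Linear(out_dim, out_dim)     # projects the blended features

h_u = torch.randn(in_dim)             # source node u's features
h_neighbors = torch.randn(5, in_dim)  # features of the 5 neighbors in N(u)

messages = F_msg(h_neighbors)   # 1. message passing: m_v for every v in N(u)
m_bar = messages.mean(dim=0)    # 2. aggregation: mean over neighborhood messages
h_u_new = torch.relu(U(K(h_u) + m_bar))  # 3. update: blend, project, activate
print(h_u_new.shape)  # torch.Size([4])
```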
Note:
GNN aims to learn a function that generates embeddings by sampling and aggregating features from a node's neighborhood. Innovations in GNNs mainly involve changes to these three steps.
The number of layers in a GNN is a hyperparameter that can be tweaked; the intuition is that $k$ GNN layers aggregate features from the $k$-hop neighborhood of node $u$, i.e. initially the node sees its immediate neighbors, and deeper into the network, it interacts with neighbors' neighbors and so on. Most GNN papers use fewer than 4 layers to prevent the network from over-smoothing, where node embeddings all converge to similar representations after seeing nodes many hops away. This phenomenon becomes more prevalent for small and sparse graphs.
Implementation
This code is largely ported from DGL's PyTorch Lightning node classification example, with additional explanations in between each section.
We will be using the ogbn-products dataset. The following description is copied directly from the dataset's description page.
ogbn-products dataset is an undirected and unweighted graph, representing an Amazon product co-purchasing network. Nodes represent products sold in Amazon, and edges between two products indicate that the products are purchased together. Node features are generated by extracting bag-of-words features from the product descriptions followed by a Principal Component Analysis to reduce the dimension to 100.
The task is to predict the category of a product in a multi-class classification setup, where the 47 top-level categories are used for target labels.
For a DGL graph, we can assign or extract node features via the graph's ndata attribute. Here, we assign the dataset's labels to our graph.
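A minimal sketch, assuming the ogb package's DGL-ready dataset loader:

```python
# a minimal sketch, assuming the ogb package is installed
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset(name="ogbn-products")
graph, labels = dataset[0]  # a DGLGraph and a (num_nodes, 1) label tensor

# assign the dataset's labels as a node feature on the graph
graph.ndata["label"] = labels.squeeze()

# OGB also provides the official train/valid/test node id splits
split_idx = dataset.get_idx_split()
train_idx = split_idx["train"]
```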
Data Module
Given a graph as well as data splits, we can get our hands dirty and implement our data module.
Similar to general neural networks, we need a DataLoader to sample batches of inputs. A data loader by default returns 3 elements: input_nodes, output_nodes, blocks.
input_nodes describe the nodes needed to compute the representation of output_nodes, whereas blocks describe, for each GNN layer, which node representations are to be computed as output, which node representations are needed as input, and how representations from the input nodes propagate to the output nodes.
Each data loader also accepts a sampler. Here we are using one called NeighborSampler, which makes every node gather messages from a fixed number of neighbors. We get to define the fanouts parameter for the sampler, which represents the number of neighbors to sample for each GNN layer.
It also supports PyTorch concepts such as prefetching, so model computation and data movement can happen in parallel, as well as a concept called UVA (unified virtual addressing). Directly quoting from its documentation: this is for when our graph is too large to fit onto GPU memory, and we let the GPU perform sampling on a graph that is pinned in CPU memory.
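A minimal sketch of such a data loader; the fanouts and batch size here are illustrative choices, not values from the original example:

```python
# a minimal sketch of neighbor sampling with DGL
import dgl

# sample 15 neighbors for the 1st GNN layer, 10 for the 2nd, 5 for the 3rd
sampler = dgl.dataloading.NeighborSampler(
    [15, 10, 5],
    prefetch_node_feats=["feat"],  # prefetch so compute and data movement overlap
    prefetch_labels=["label"],
)
train_loader = dgl.dataloading.DataLoader(
    graph, train_idx, sampler,
    batch_size=1024, shuffle=True, drop_last=False, num_workers=0,
    # use_uva=True would let the GPU sample from a graph pinned in CPU memory
)

input_nodes, output_nodes, blocks = next(iter(train_loader))
print(len(blocks))  # 3, one block per GNN layer
```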
GraphSAGE Model
GraphSAGE stands for Graph SAmple and aggreGatE. Its forward pass can be described with the following notation:

$$h_{N(u)}^{(k)} = \text{aggregate}\big(\{h_v^{(k-1)}, \forall v \in N(u)\}\big)$$

$$h_u^{(k)} = \sigma\big(W^{(k)} \cdot \text{concat}\big(h_u^{(k-1)}, h_{N(u)}^{(k)}\big)\big)$$

$$h_u^{(k)} = h_u^{(k)} \, / \, \lVert h_u^{(k)} \rVert_2$$
Hopefully each of these steps won't look that alien after covering the general pattern of GNNs.
For each node, it aggregates feature representations from its immediate neighborhood, which can be uniformly sampled. The original paper uses aggregation functions such as mean, pooling, and LSTM.
After aggregating neighboring feature representations, it concatenates them with the node's current representation. This concatenation is then fed through a fully connected layer with a nonlinear activation function.
The last step is normalizing the learned embedding to unit length.
The way this works in DGL is: if our features are stored in a graph object's ndata, then from a sampled block object we can access the source nodes' features via srcdata and the destination nodes' features via dstdata. In the next few code chunks, we first perform a small demo where we access the source nodes' features, feed them through a GNN layer, and check whether the output shape matches the output nodes' label size. After that, we'll proceed with implementing our main model/module.
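Below is a minimal sketch of such a model using DGL's built-in SAGEConv layer; the two-layer depth, hidden size, and dropout rate are illustrative.

```python
# a minimal sketch of a two-layer GraphSAGE module operating on sampled blocks
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn

class SAGE(nn.Module):
    def __init__(self, in_feats, hidden_feats, n_classes):
        super().__init__()
        self.layers = nn.ModuleList([
            dglnn.SAGEConv(in_feats, hidden_feats, "mean"),
            dglnn.SAGEConv(hidden_feats, n_classes, "mean"),
        ])
        self.dropout = nn.Dropout(0.5)

    def forward(self, blocks, x):
        # each block pairs with one layer: block.srcdata holds the inputs the
        # layer needs, block.dstdata the nodes whose representations it outputs
        h = x
        for i, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if i != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h
```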
Trainer
The next code chunk initializes the data module as well as the model module and kicks off model training through the Trainer class.
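A sketch of what that looks like; DataModule and SAGELightning are hypothetical placeholder names for the LightningDataModule/LightningModule built in the earlier sections:

```python
# a minimal sketch; DataModule/SAGELightning are hypothetical placeholder names
import pytorch_lightning as pl

datamodule = DataModule()   # wraps the graph, splits, and DGL data loaders
model = SAGELightning()     # wraps the SAGE model, loss, and metrics

trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=10)
trainer.fit(model, datamodule=datamodule)
```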
Evaluation
The prediction/evaluation is also a bit interesting. As explained clearly in DGL Tutorial - Exact Offline Inference on Large Graphs: while training our GNN, we oftentimes perform neighborhood sampling to reduce memory usage, but when performing inference, it's better to truly aggregate over all neighbors.
The result of this is that our inference implementation will be slightly different compared to training. During training, we have an outer loop that iterates over mini-batches of nodes (this comes from our DataLoader) and an inner loop that iterates over all our GNN's layers. During inference, we instead have an outer loop that iterates over the GNN layers and an inner loop that iterates over mini-batches of nodes.
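A sketch of this layer-wise inference loop, assuming the SAGE module above; a single-layer full-neighbor sampler replaces the training-time fanouts, and the batch size is illustrative:

```python
# a sketch of layer-wise offline inference: the OUTER loop walks the GNN
# layers, the INNER loop walks mini-batches of nodes
import torch
import dgl

@torch.no_grad()
def inference(model, graph, features, device, batch_size=4096):
    # full-neighbor sampler over one layer: truly aggregate over ALL neighbors
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
    dataloader = dgl.dataloading.DataLoader(
        graph, torch.arange(graph.num_nodes()), sampler,
        batch_size=batch_size, shuffle=False, drop_last=False,
    )
    h = features
    for i, layer in enumerate(model.layers):
        outputs = []
        for input_nodes, output_nodes, blocks in dataloader:
            block = blocks[0].to(device)
            h_out = layer(block, h[input_nodes].to(device))
            if i != len(model.layers) - 1:
                h_out = torch.relu(h_out)
            outputs.append(h_out.cpu())
        # shuffle=False keeps batches in node id order, so cat reassembles h
        h = torch.cat(outputs)
    return h
```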
Hopefully, this served as a quick introduction to node classification with GNNs. Feel free to check the leaderboard for potential improvements to this baseline approach.