ethen8181
GitHub Repository: ethen8181/machine-learning
Path: blob/master/deep_learning/gnn/gnn_node_classification_intro.ipynb
Kernel: Python 3 (ipykernel)
# code for loading the format for the notebook
import os

# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', 'notebook_format'))
from formats import load_style
load_style(css_style='custom2.css', plot_style=False)
os.chdir(path)

# 1. magic to print version
# 2. magic so that the notebook will reload external python modules
%load_ext watermark
%load_ext autoreload
%autoreload 2

import dgl
import torch
import dgl.data
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchmetrics.functional as MF
from time import perf_counter
from torchmetrics import Accuracy
from ogb.nodeproppred import DglNodePropPredDataset
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

%watermark -a "Ethen" -d -u -v -iv
Author: Ethen
Last updated: 2022-10-22
Python implementation: CPython
Python version       : 3.8.10
IPython version      : 8.4.0
torch       : 1.10.0a0+git36449ea
dgl         : 0.9.1
torchmetrics: 0.10.0

Graph Neural Networks Node Classification Quick Introduction

In this notebook, we will:

  • Give a quick introduction to Graph Neural Networks (GNNs).

  • Implement one particular algorithm, GraphSAGE, using DGL (Deep Graph Library), and introduce how to work with the DGL library on a graph node classification task.

  • Use PyTorch Lightning to organize our building blocks and train our model. This isn't a PyTorch Lightning tutorial; readers are expected to be familiar with its concepts such as LightningDataModule, LightningModule, and Trainer.

A graph $\mathcal{G}(V, E)$ is a data structure containing a set of nodes (a.k.a. vertices) $i \in V$ and a set of edges $e_{ij} \in E$ connecting vertices $i$ and $j$. Each node $i$ has associated node features $x_i \in \mathbb{R}^d$ and a label $y_i$.
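To make the notation concrete, here is a minimal sketch of a toy undirected graph in plain Python. The four nodes, the edges, the 2-dimensional features and the labels are made-up values, and `neighborhoods` is a hypothetical helper (not a DGL function) that collects, for every node, the set of nodes it shares an edge with.

```python
from typing import Dict, List, Set, Tuple

# Made-up toy graph G(V, E): V = {0, 1, 2, 3}, each undirected edge listed once.
num_nodes = 4
edges: List[Tuple[int, int]] = [(0, 1), (0, 2), (1, 2), (2, 3)]

# Node features x_i in R^d with d = 2, plus one class label y_i per node.
features: List[List[float]] = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
labels: List[int] = [0, 1, 1, 0]


def neighborhoods(num_nodes: int, edges: List[Tuple[int, int]]) -> Dict[int, Set[int]]:
    """Map every node i to the set of nodes j that share an edge with it."""
    nbrs: Dict[int, Set[int]] = {i: set() for i in range(num_nodes)}
    for i, j in edges:
        # undirected graph: an edge e_ij lets messages flow both ways
        nbrs[i].add(j)
        nbrs[j].add(i)
    return nbrs


N = neighborhoods(num_nodes, edges)
print({i: sorted(js) for i, js in N.items()})
# {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
```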

A single graph neural network (GNN) layer has three main steps that are performed on every node in the graph:

  1. Message Passing

  2. Aggregation

  3. Update

Message Passing:

A GNN learns a representation of node $i$ by examining the nodes in its neighborhood $\mathcal{N}_i$, defined as the set of nodes $j$ connected to the source node $i$ by an edge; more formally, $\mathcal{N}_i = \{j : e_{ij} \in E\}$. To examine a neighbor, we can use any arbitrary function $F$, e.g. a neural network such as an MLP; here, let's assume it is an affine transformation:

$$F(x_j) = \mathbf{W}_j \cdot x_j + b$$

Here, $\cdot$ represents matrix multiplication.
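A minimal sketch of the affine message function in plain Python, with no external libraries. One assumption to flag: it uses a single weight matrix `W` and bias `b` shared by every neighbor $j$ (the usual choice in practice) rather than a separate $\mathbf{W}_j$ per neighbor, and the numbers are made up.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(W: Matrix, x: Vector) -> Vector:
    """Plain matrix-vector product W . x."""
    return [sum(w_rc * x_c for w_rc, x_c in zip(row, x)) for row in W]


def message(W: Matrix, b: Vector, x_j: Vector) -> Vector:
    """Affine message F(x_j) = W . x_j + b (W assumed shared across neighbors)."""
    return [v + b_r for v, b_r in zip(matvec(W, x_j), b)]


# Hypothetical 3x2 weights: project 2-d neighbor features into 3-d messages.
W = [[1.0, 0.0],
     [0.0, 2.0],
     [1.0, -1.0]]
b = [0.0, 0.5, 0.0]
print(message(W, b, [1.0, 1.0]))  # [1.0, 2.5, 0.0]
```

Note that the message can have a different dimension than the input features; here $W$ maps 2-d features to 3-d messages.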

Aggregation:

Now that we have the messages from the neighboring nodes, we need to aggregate them. Popular aggregation functions include sum, mean, max, and min. The aggregation function $G$ can be denoted as:

$$\bar{m}_i = G(\{F(x_j) : j \in \mathcal{N}_i\})$$
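The aggregation step can be sketched in plain Python as an element-wise reduction over the neighbors' message vectors. This sketch assumes every node has at least one neighbor (the mean would otherwise divide by zero); `aggregate` and `AGGREGATORS` are hypothetical names. All four functions are permutation invariant, so the result does not depend on the order in which neighbors are visited, which matters because a node's neighborhood has no natural ordering.

```python
from typing import Callable, Dict, List

Vector = List[float]

# Element-wise aggregators: each reduces one feature dimension across all neighbors.
AGGREGATORS: Dict[str, Callable[[List[float]], float]] = {
    "sum": sum,
    "mean": lambda col: sum(col) / len(col),
    "max": max,
    "min": min,
}


def aggregate(messages: List[Vector], how: str = "mean") -> Vector:
    """m_bar_i = G({F(x_j) : j in N_i}), applied independently to each dimension."""
    reduce = AGGREGATORS[how]
    # zip(*messages) groups the k-th entry of every neighbor's message together
    return [reduce(list(col)) for col in zip(*messages)]


msgs = [[1.0, 2.0], [3.0, -2.0], [2.0, 0.0]]
for how in AGGREGATORS:
    print(how, aggregate(msgs, how))
# sum [6.0, 0.0]
# mean [2.0, 0.0]
# max [3.0, 2.0]
# min [1.0, -2.0]
```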

Update:

Finally, the GNN layer updates the source node $i$'s features by combining them with the incoming aggregated message, for example via addition:

$$h_i = \sigma\left(K\left(T(x_i) + \bar{m}_i\right)\right)$$

Here, $T$ denotes a function applied to the source node $i$'s features, $K$ denotes another transformation that projects the combined features into another dimension, and $\sigma$ denotes an activation function such as ReLU. Notation-wise, the initial node features are called $x_i$; after a forward pass through a GNN layer, we denote the node features as $h_i$. With multiple GNN layers, we denote the node features as $h_i^l$, where $l$ is the current GNN layer index.
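Putting the three steps together, here is a minimal plain-Python sketch of one GNN layer. It assumes $F$, $T$ and $K$ are bias-free linear maps, $G$ is the mean, $\sigma$ is ReLU, and every node has at least one neighbor; `gnn_layer` and its parameter names are hypothetical and not part of DGL.

```python
from typing import Dict, List, Set

Vector = List[float]
Matrix = List[List[float]]


def matvec(W: Matrix, x: Vector) -> Vector:
    """Plain matrix-vector product W . x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def relu(x: Vector) -> Vector:
    return [max(0.0, v) for v in x]


def gnn_layer(
    x: List[Vector],
    nbrs: Dict[int, Set[int]],
    W_msg: Matrix,   # F: message transform (bias omitted for brevity)
    W_self: Matrix,  # T: transform of the source node's own features
    W_out: Matrix,   # K: projection of the combined features
) -> List[Vector]:
    """Compute h_i = relu(K(T(x_i) + mean_j F(x_j))) for every node i."""
    h: List[Vector] = []
    for i in range(len(x)):
        # 1. message passing: F(x_j) for every neighbor j in N_i
        msgs = [matvec(W_msg, x[j]) for j in sorted(nbrs[i])]
        # 2. aggregation: element-wise mean over the neighbors' messages
        m_bar = [sum(col) / len(col) for col in zip(*msgs)]
        # 3. update: add the transformed self features, project with K, apply sigma
        combined = [t + m for t, m in zip(matvec(W_self, x[i]), m_bar)]
        h.append(relu(matvec(W_out, combined)))
    return h


x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
nbrs = {0: {1, 2}, 1: {0, 2}, 2: {0, 1, 3}, 3: {2}}
I = [[1.0, 0.0], [0.0, 1.0]]  # identity weights keep the arithmetic easy to follow
h = gnn_layer(x, nbrs, W_msg=I, W_self=I, W_out=I)
print(h[0])  # [1.5, 1.0]
```

With identity weights, node 0's output is simply its own features $[1, 0]$ plus the mean of its neighbors' features $[0.5, 1]$; in a real GNN the three matrices are learned.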

Note:

  • A GNN aims to learn a function that generates node embeddings by sampling and aggregating features from each node's neighborhood. Innovations in GNNs mainly involve changes to these three steps.

  • The number of layers in a GNN is a hyperparameter that can be tuned. The intuition is that the $l^{th}$ GNN layer aggregates features from the $l$-hop neighborhood of node $i$: initially, a node sees only its immediate neighbors, and deeper into the network, it interacts with its neighbors' neighbors, and so on. Most GNN papers use fewer than 4 layers to avoid over-smoothing, where node embeddings all converge to similar representations after aggregating information from nodes many hops away. This phenomenon is more prevalent in small and sparse graphs.
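The hop intuition can be checked with a small breadth-first search: the set of nodes whose input features can influence node $i$ after $l$ layers is its $l$-hop neighborhood, including $i$ itself because the update step also uses the node's own features. A sketch on a hypothetical path graph; `receptive_field` is a made-up helper name.

```python
from typing import Dict, Set


def receptive_field(nbrs: Dict[int, Set[int]], node: int, n_layers: int) -> Set[int]:
    """Nodes whose input features can reach `node` after n_layers rounds of message passing."""
    frontier = {node}
    seen = {node}
    for _ in range(n_layers):
        # each extra layer pulls in the not-yet-seen neighbors of the current frontier
        frontier = {j for i in frontier for j in nbrs[i]} - seen
        seen |= frontier
    return seen


# Hypothetical path graph 0 - 1 - 2 - 3 - 4
path = {0: {1}, 1: {0, 2}, 2: {1, 3}, 3: {2, 4}, 4: {3}}
for l in range(4):
    print(l, sorted(receptive_field(path, 0, l)))
# 0 [0]
# 1 [0, 1]
# 2 [0, 1, 2]
# 3 [0, 1, 2, 3]
```

On a well-connected graph this set quickly covers most of the graph, which gives one way to see why very deep GNNs tend to produce similar embeddings for every node.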

Implementation

This code is largely ported from DGL's PyTorch Lightning node classification example, with additional explanations between sections.

We will be using the ogbn-products dataset. The following description is copied directly from the dataset's description page:

ogbn-products dataset is an undirected and unweighted graph, representing an Amazon product co-purchasing network. Nodes represent products sold in Amazon, and edges between two products indicate that the products are purchased together. Node features are generated by extracting bag-of-words features from the product descriptions followed by a Principal Component Analysis to reduce the dimension to 100.

The task is to predict the category of a product in a multi-class classification setup, where the 47 top-level categories are used for target labels.

dataset = DglNodePropPredDataset("ogbn-products")
graph, labels = dataset[0]
graph
Graph(num_nodes=2449029, num_edges=123718280, ndata_schemes={'feat': Scheme(shape=(100,), dtype=torch.float32)} edata_schemes={})

For a DGL graph, we can assign or extract node features via the graph's ndata attribute. Here, we assign the dataset's labels to our graph.

graph.ndata["label"] = labels.squeeze()
graph.ndata
{'feat': tensor([[ 0.0319, -0.1959, 0.0520, ..., 0.0767, -0.3930, -0.0648], [-0.0241, 0.6303, 1.0606, ..., -1.6875, 3.5867, 0.8182], [ 0.3327, -0.5586, -0.2886, ..., -0.3716, 0.2521, 0.0415], ..., [ 0.1066, 0.2655, -0.0057, ..., 1.0867, 0.0759, -1.1737], [ 0.2497, -0.2574, 0.4123, ..., 1.5466, 1.0310, -0.2966], [ 0.7175, -0.2393, 0.0443, ..., -1.0132, -0.4141, -0.0823]]), 'label': tensor([0, 1, 2, ..., 8, 2, 4])}
# extract the train, validation and test split provided via the dataset
split_idx = dataset.get_idx_split()
train_idx, val_idx, test_idx = (
    split_idx["train"],
    split_idx["valid"],
    split_idx["test"],
)

Data Module

Given a graph as well as data splits, we can get our hands dirty and implement our data module.

class DataModule(LightningDataModule):

    def __init__(
        self, graph, train_idx, val_idx, fanouts, batch_size, n_classes, device
    ):
        super().__init__()
        # fanouts[l] = number of neighbors to sample for GNN layer l; node features and
        # labels are prefetched so data movement can overlap with model computation
        sampler = dgl.dataloading.NeighborSampler(
            fanouts, prefetch_node_feats=["feat"], prefetch_labels=["label"]
        )
        self.graph = graph
        self.train_idx = train_idx
        self.val_idx = val_idx
        self.sampler = sampler
        # store the device passed in, instead of silently relying on a global variable
        self.device = device
        self.batch_size = batch_size
        self.in_feats = graph.ndata["feat"].shape[1]
        self.n_classes = n_classes

    def train_dataloader(self):
        # use_uva=True: the graph stays in pinned CPU memory and the GPU samples from it
        return dgl.dataloading.DataLoader(
            self.graph,
            self.train_idx.to(self.device),
            self.sampler,
            device=self.device,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=0,
            use_uva=True,
        )

    def val_dataloader(self):
        return dgl.dataloading.DataLoader(
            self.graph,
            self.val_idx.to(self.device),
            self.sampler,
            device=self.device,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=0,
            use_uva=True,
        )

As with general neural network training, we need a DataLoader to sample mini-batches of inputs. By default, DGL's data loader returns 3 elements: input_nodes, output_nodes, and blocks.

input_nodes are the nodes needed to compute the representations of output_nodes, whereas blocks describe, for each GNN layer, which node representations are computed as output, which node representations are needed as input, and how representations from the input nodes propagate to the output nodes.

Each data loader also accepts a sampler. Here we use NeighborSampler, which makes every node gather messages from a fixed number of neighbors. The sampler's fanouts parameter specifies the number of neighbors to sample for each GNN layer.

The data loader also supports prefetching (the prefetch_node_feats and prefetch_labels arguments of the sampler above), so that model computation and data movement can happen in parallel, as well as UVA (unified virtual addressing). According to DGL's documentation, UVA is meant for graphs that are too large to fit in GPU memory: the graph is pinned in CPU memory and the GPU performs sampling on it directly.
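To build intuition for what the sampler hands back, here is a simplified plain-Python sketch of fanout-based neighbor sampling. It is not DGL's implementation: `sample_blocks` and its tuple-based "blocks" are hypothetical stand-ins, and it assumes sampling without replacement plus the DGL block convention that each layer's destination nodes are listed first among its source nodes.

```python
import random
from typing import Dict, List, Sequence, Set, Tuple

# (src_nodes, dst_nodes, edges) where each edge is (src, dst)
Block = Tuple[List[int], List[int], List[Tuple[int, int]]]


def sample_blocks(
    nbrs: Dict[int, Set[int]],
    seeds: Sequence[int],
    fanouts: Sequence[int],
    rng: random.Random,
) -> Tuple[List[int], List[int], List[Block]]:
    """Return (input_nodes, output_nodes, blocks), blocks ordered from first to last layer.

    fanouts[l] is how many neighbors each destination node of layer l samples, so the
    sampling walks backwards: it starts at the seed (output) nodes with fanouts[-1].
    """
    dst = list(seeds)
    blocks: List[Block] = []
    for fanout in reversed(fanouts):
        edges: List[Tuple[int, int]] = []
        for v in dst:
            candidates = sorted(nbrs[v])
            # sample without replacement, capped by the node's degree
            picked = rng.sample(candidates, min(fanout, len(candidates)))
            edges.extend((u, v) for u in picked)
        # destination nodes come first among the source nodes, so a layer can
        # still use each node's own representation (e.g. GraphSAGE's self term)
        new_src = sorted({u for u, _ in edges} - set(dst))
        src = dst + new_src
        blocks.append((src, dst, edges))
        # this layer's inputs are exactly what the previous layer must output
        dst = src
    blocks.reverse()
    return blocks[0][0], list(seeds), blocks


graph = {0: {1, 2, 3}, 1: {0, 2}, 2: {0, 1, 3, 4}, 3: {0, 2}, 4: {2}}
input_nodes, output_nodes, blocks = sample_blocks(
    graph, seeds=[0], fanouts=[2, 2], rng=random.Random(0)
)
for src, dst, edges in blocks:
    print(f"num_src_nodes={len(src)}, num_dst_nodes={len(dst)}, num_edges={len(edges)}")
```

The same pattern shows up in the notebook's sampled output: each block's num_dst_nodes equals the next block's num_src_nodes (126 and 12), and num_edges never exceeds fanout × num_dst_nodes (e.g. 1847 ≤ 15 × 126).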

fanouts = [15, 10, 5]
batch_size = 2
data_module = DataModule(
    graph, train_idx, val_idx, fanouts, batch_size, dataset.num_classes, device
)
# sample output from the data loader
input_nodes, output_nodes, blocks = next(iter(data_module.train_dataloader()))
input_nodes, output_nodes, blocks
(tensor([106546, 74635, 187662, ..., 232149, 215355, 80118], device='cuda:0'), tensor([106546, 74635], device='cuda:0'), [Block(num_src_nodes=1428, num_dst_nodes=126, num_edges=1847), Block(num_src_nodes=126, num_dst_nodes=12, num_edges=120), Block(num_src_nodes=12, num_dst_nodes=2, num_edges=10)])

GraphSAGE Model

GraphSAGE stands for Graph SAmple and aggreGatE. Its forward pass can be described with the following notation:

$$
\begin{aligned}
h_{\mathcal{N}_i}^{(l+1)} &= \mathrm{aggregate}^{(l+1)}\left(\{h_j^{(l)}, \forall j \in \mathcal{N}_i\}\right) \\
h_i^{(l+1)} &= \sigma\left(W^{(l+1)} \cdot \mathrm{concat}\left(h_i^{(l)}, h_{\mathcal{N}_i}^{(l+1)}\right)\right) \\
h_i^{(l+1)} &= \mathrm{norm}\left(h_i^{(l+1)}\right)
\end{aligned}
$$

Hopefully, after covering the general GNN pattern, each of these steps won't look too alien:

  • For each node, it aggregates feature representations from its immediate neighborhood, which can be uniformly sampled. The original paper uses aggregation functions such as mean, pooling, and LSTM.

  • After aggregating the neighboring feature representations, it concatenates the result with the node's current representation. This concatenation is then fed through a fully connected layer with a nonlinear activation function.

  • The last step normalizes the learned embedding to unit length.
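The three GraphSAGE steps can be sketched in plain Python for the mean aggregator. This is a hypothetical illustration (`sage_layer` is a made-up name, not DGL's SAGEConv), assuming ReLU as $\sigma$, L2 normalization as $\mathrm{norm}$, and at least one neighbor per node.

```python
import math
from typing import Dict, List, Set

Vector = List[float]
Matrix = List[List[float]]


def sage_layer(h: List[Vector], nbrs: Dict[int, Set[int]], W: Matrix) -> List[Vector]:
    """One GraphSAGE-mean layer: aggregate, concatenate, linear + ReLU, L2-normalize."""
    out: List[Vector] = []
    for i in range(len(h)):
        # h_{N_i}: element-wise mean of the neighbors' current representations
        neigh = [h[j] for j in sorted(nbrs[i])]
        h_nbr = [sum(col) / len(col) for col in zip(*neigh)]
        # concat(h_i, h_{N_i}): list concatenation, so W needs 2 * d columns
        z = h[i] + h_nbr
        a = [max(0.0, sum(w * v for w, v in zip(row, z))) for row in W]
        # norm step: rescale to unit length (left as-is if ReLU zeroed everything)
        length = math.sqrt(sum(v * v for v in a))
        out.append([v / length for v in a] if length > 0 else a)
    return out


h0 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
nbrs = {0: {1, 2}, 1: {0}, 2: {0}}
# Hypothetical 2 x 4 weights: row 0 reads the node's own dim 0, row 1 the neighbors' dim 1.
W = [[1.0, 0.0, 0.0, 0.0],
     [0.0, 0.0, 0.0, 1.0]]
print(sage_layer(h0, nbrs, W))
```

Splitting $W$ column-wise into a self half and a neighbor half gives $W \cdot \mathrm{concat}(h_i, h_{\mathcal{N}_i}) = W_{self} h_i + W_{neigh} h_{\mathcal{N}_i}$, which is why the printed SAGEConv modules later in this notebook carry two linear layers, fc_self and fc_neigh.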

In DGL, if our features are stored in a graph object's ndata, then from a sampled block we can access the source nodes' features via srcdata and the destination nodes' features via dstdata. In the next few code cells, we first run a small demo: we access the source nodes' features, feed them through a GNN layer, and check whether the output's first dimension matches the number of destination nodes' labels. After that, we'll implement our main model module.

blocks[0].srcdata["feat"].shape
torch.Size([1428, 100])
# out_feats is configurable, analogous to a hidden layer's dimension size
sage_conv = dgl.nn.SAGEConv(
    in_feats=data_module.in_feats,
    out_feats=256,
    aggregator_type="mean"
).to(device)
output = sage_conv(blocks[0], blocks[0].srcdata["feat"])
output.shape
torch.Size([126, 256])
blocks[0].dstdata["label"].shape
torch.Size([126])
class SAGE(LightningModule):
    """Multi-layer GraphSAGE lightning module for node classification task."""

    def __init__(self, in_feats: int, n_layers: int, n_hidden: int, n_classes: int, aggregator_type: str):
        super().__init__()
        self.save_hyperparameters()
        # input layer -> (n_layers - 2) hidden layers -> output layer with n_classes logits
        self.layers = nn.ModuleList()
        self.layers.append(dgl.nn.SAGEConv(in_feats, n_hidden, aggregator_type))
        for i in range(1, n_layers - 1):
            self.layers.append(dgl.nn.SAGEConv(n_hidden, n_hidden, aggregator_type))
        self.layers.append(dgl.nn.SAGEConv(n_hidden, n_classes, aggregator_type))
        self.dropout = nn.Dropout(0.5)
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()

    def forward(self, blocks, x):
        h = x
        # blocks[l] holds the sampled subgraph that GNN layer l operates on
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

    def training_step(self, batch, batch_idx):
        input_nodes, output_nodes, blocks = batch
        # features of the first block's source nodes, labels of the last block's destination nodes
        x = blocks[0].srcdata["feat"]
        y = blocks[-1].dstdata["label"]
        y_hat = self(blocks, x)
        loss = F.cross_entropy(y_hat, y)
        self.train_acc(torch.argmax(y_hat, 1), y)
        self.log(
            "train_acc", self.train_acc, prog_bar=True, on_step=True, on_epoch=False
        )
        return loss

    def validation_step(self, batch, batch_idx):
        input_nodes, output_nodes, blocks = batch
        x = blocks[0].srcdata["feat"]
        y = blocks[-1].dstdata["label"]
        y_hat = self(blocks, x)
        self.val_acc(torch.argmax(y_hat, 1), y)
        self.log(
            "val_acc", self.val_acc, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True
        )

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(), lr=0.001, weight_decay=5e-4
        )
        return optimizer
model = SAGE(
    in_feats=data_module.in_feats,
    n_layers=len(fanouts),
    n_hidden=256,
    n_classes=data_module.n_classes,
    aggregator_type="mean"
)
model
SAGE(
  (layers): ModuleList(
    (0): SAGEConv(
      (feat_drop): Dropout(p=0.0, inplace=False)
      (fc_self): Linear(in_features=100, out_features=256, bias=False)
      (fc_neigh): Linear(in_features=100, out_features=256, bias=False)
    )
    (1): SAGEConv(
      (feat_drop): Dropout(p=0.0, inplace=False)
      (fc_self): Linear(in_features=256, out_features=256, bias=False)
      (fc_neigh): Linear(in_features=256, out_features=256, bias=False)
    )
    (2): SAGEConv(
      (feat_drop): Dropout(p=0.0, inplace=False)
      (fc_self): Linear(in_features=256, out_features=47, bias=False)
      (fc_neigh): Linear(in_features=256, out_features=47, bias=False)
    )
  )
  (dropout): Dropout(p=0.5, inplace=False)
  (train_acc): Accuracy()
  (val_acc): Accuracy()
)

Trainer

The following code chunk initializes both the data module and the model module, then kicks off model training through the Trainer class.

n_hidden = 256
fanouts = [15, 10, 5]
aggregator_type = "mean"
batch_size = 1024
n_layers = len(fanouts)

data_module = DataModule(
    graph, train_idx, val_idx, fanouts, batch_size, dataset.num_classes, device
)
model = SAGE(
    in_feats=data_module.in_feats,
    n_layers=n_layers,
    n_hidden=n_hidden,
    n_classes=data_module.n_classes,
    aggregator_type=aggregator_type
)

checkpoint_callback = ModelCheckpoint(monitor="val_acc", save_top_k=1)
trainer = Trainer(
    accelerator='gpu',
    devices=[0],
    max_epochs=10,
    # note, we purposefully disabled the progress bar to prevent flooding our notebook's console
    # in normal settings, we can/should definitely turn it on
    enable_progress_bar=False,
    log_every_n_steps=100,
    callbacks=[checkpoint_callback]
)

t1_start = perf_counter()
trainer.fit(model, datamodule=data_module)
t1_stop = perf_counter()
print("Elapsed time:", t1_stop - t1_start)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type       | Params
-----------------------------------------
0 | layers    | ModuleList | 206 K
1 | dropout   | Dropout    | 0
2 | train_acc | Accuracy   | 0
3 | val_acc   | Accuracy   | 0
-----------------------------------------
206 K     Trainable params
0         Non-trainable params
206 K     Total params
0.828     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=10` reached.
Elapsed time: 117.55090914410539

Evaluation

The prediction/evaluation step is also a bit interesting. As explained clearly in the DGL tutorial "Exact Offline Inference on Large Graphs", while training our GNN we often perform neighborhood sampling to reduce memory usage, but during inference it's better to truly aggregate over all neighbors.

As a result, our inference implementation differs slightly from training. During training, we have an outer loop that iterates over mini-batches of nodes (coming from our DataLoader) and an inner loop that iterates over all of our GNN's layers. During inference, we instead have an outer loop that iterates over the GNN layers and an inner loop that iterates over mini-batches of nodes.
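The two loop orders can be contrasted with a small plain-Python sketch of layer-wise exact inference. It is a hypothetical simplification, not DGL's implementation: `layerwise_inference` and `add_layer` are made-up names, each layer is a plain function of a node's own features and the mean of all its neighbors' features, and every node is assumed to have at least one neighbor. The list `y` plays the same role as the per-layer output tensor in the notebook's predict function.

```python
from typing import Callable, Dict, List, Set

Vector = List[float]
# A layer maps (node's own features, mean of its neighbors' features) -> new features.
Layer = Callable[[Vector, Vector], Vector]


def layerwise_inference(
    x: List[Vector],
    nbrs: Dict[int, Set[int]],
    layers: List[Layer],
    batch_size: int,
) -> List[Vector]:
    """Exact inference: outer loop over layers, inner loop over mini-batches of nodes."""
    h = x
    for layer in layers:
        # per-layer output buffer for every node in the graph
        y: List[Vector] = [[] for _ in range(len(h))]
        for start in range(0, len(h), batch_size):
            for i in range(start, min(start + batch_size, len(h))):
                # full neighborhood, no sampling: aggregate over ALL neighbors of i
                neigh = [h[j] for j in sorted(nbrs[i])]
                h_nbr = [sum(col) / len(col) for col in zip(*neigh)]
                y[i] = layer(h[i], h_nbr)
        # only once every node has this layer's output do we move on to the next layer
        h = y
    return h


def add_layer(self_h: Vector, nbr_h: Vector) -> Vector:
    """Toy parameter-free layer: own features plus mean neighbor features."""
    return [a + b for a, b in zip(self_h, nbr_h)]


x = [[1.0], [2.0], [4.0]]
nbrs = {0: {1}, 1: {0, 2}, 2: {1}}
print(layerwise_inference(x, nbrs, [add_layer, add_layer], batch_size=2))
# [[7.5], [9.0], [10.5]]
```

Because every layer sees each node's full neighborhood, the result does not depend on `batch_size`, and each node's representation at each layer is computed exactly once.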

def predict(graph, model, batch_size, device):
    # switch layers such as dropout to inference behavior
    model.eval()
    graph.ndata["h"] = graph.ndata["feat"]
    # sample ALL neighbors for a single layer at a time
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
    data_loader = dgl.dataloading.DataLoader(
        graph,
        torch.arange(graph.number_of_nodes()).to(device),
        sampler,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        device=device,
        num_workers=0,
        use_uva=True
    )
    # outer loop over GNN layers, inner loop over mini-batches of nodes
    for l, layer in enumerate(model.layers):
        # CPU buffer holding every node's output for the current layer
        y = torch.zeros(
            graph.num_nodes(),
            model.n_hidden if l != len(model.layers) - 1 else model.n_classes,
            device='cpu'
        )
        for input_nodes, output_nodes, blocks in data_loader:
            block = blocks[0]
            x = block.srcdata['h']
            h = layer(block, x)
            if l != len(model.layers) - 1:
                h = F.relu(h)
                h = model.dropout(h)
            y[output_nodes] = h.to('cpu')
        # this layer's output becomes the next layer's input
        graph.ndata["h"] = y

    del graph.ndata['h']
    return y
predict_batch_size = 4096
with torch.no_grad():
    pred = predict(graph, model.to(device), predict_batch_size, device)
    pred = pred[test_idx]
    label = graph.ndata["label"][test_idx]
    accuracy = MF.accuracy(pred, label)

accuracy = round(accuracy.item(), 3)
print("Test accuracy:", accuracy)
Test accuracy: 0.748

Hopefully, this served as a quick introduction to node classification with GNNs. Feel free to check the ogbn-products leaderboard for potential improvements over this baseline approach.

Reference