Path: blob/master/labml_nn/graphs/gat/experiment.py
"""1---2title: Train a Graph Attention Network (GAT) on Cora dataset3summary: >4This trains is a Graph Attention Network (GAT) on Cora dataset5---67# Train a Graph Attention Network (GAT) on Cora dataset8"""910from typing import Dict1112import numpy as np13import torch14from torch import nn1516from labml import lab, monit, tracker, experiment17from labml.configs import BaseConfigs, option, calculate18from labml.utils import download19from labml_nn.helpers.device import DeviceConfigs20from labml_nn.graphs.gat import GraphAttentionLayer21from labml_nn.optimizers.configs import OptimizerConfigs222324class CoraDataset:25"""26## [Cora Dataset](https://linqs.soe.ucsc.edu/data)2728Cora dataset is a dataset of research papers.29For each paper we are given a binary feature vector that indicates the presence of words.30Each paper is classified into one of 7 classes.31The dataset also has the citation network.3233The papers are the nodes of the graph and the edges are the citations.3435The task is to classify the nodes to the 7 classes with feature vectors and36citation network as input.37"""38# Labels for each node39labels: torch.Tensor40# Set of class names and an unique integer index41classes: Dict[str, int]42# Feature vectors for all nodes43features: torch.Tensor44# Adjacency matrix with the edge information.45# `adj_mat[i][j]` is `True` if there is an edge from `i` to `j`.46adj_mat: torch.Tensor4748@staticmethod49def _download():50"""51Download the dataset52"""53if not (lab.get_data_path() / 'cora').exists():54download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',55lab.get_data_path() / 'cora.tgz')56download.extract_tar(lab.get_data_path() / 'cora.tgz', lab.get_data_path())5758def __init__(self, include_edges: bool = True):59"""60Load the dataset61"""6263# Whether to include edges.64# This is test how much accuracy is lost if we ignore the citation network.65self.include_edges = include_edges6667# Download dataset68self._download()6970# Read the paper ids, feature vectors, and labels71with monit.section('Read content file'):72content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))73# Load the citations, it's a list of pairs of integers.74with monit.section('Read citations file'):75citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)7677# Get the feature vectors78features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))79# Normalize the feature vectors80self.features = features / features.sum(dim=1, keepdim=True)8182# Get the class names and assign an unique integer to each of them83self.classes = {s: i for i, s in enumerate(set(content[:, -1]))}84# Get the labels as those integers85self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)8687# Get the paper ids88paper_ids = np.array(content[:, 0], dtype=np.int32)89# Map of paper id to index90ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}9192# Empty adjacency matrix - an identity matrix93self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)9495# Mark the citations in the adjacency matrix96if self.include_edges:97for e in citations:98# The pair of paper indexes99e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]100# We build a symmetrical graph, where if paper $i$ referenced101# paper $j$ we place an adge from $i$ to $j$ as well as an edge102# from $j$ to $i$.103self.adj_mat[e1][e2] = True104self.adj_mat[e2][e1] = True105106107class GAT(nn.Module):108"""109## Graph Attention Network (GAT)110111This 


class GAT(nn.Module):
    """
    ## Graph Attention Network (GAT)

    This graph attention network has two [graph attention layers](index.html).
    """

    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):
        """
        * `in_features` is the number of features per node
        * `n_hidden` is the number of features in the first graph attention layer
        * `n_classes` is the number of classes
        * `n_heads` is the number of heads in the graph attention layers
        * `dropout` is the dropout probability
        """
        super().__init__()

        # First graph attention layer where we concatenate the heads
        self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)
        # Activation function after the first graph attention layer
        self.activation = nn.ELU()
        # Final graph attention layer where we average the heads
        self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)
        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):
        """
        * `x` is the feature vectors of shape `[n_nodes, in_features]`
        * `adj_mat` is the adjacency matrix of the form
        `[n_nodes, n_nodes, n_heads]` or `[n_nodes, n_nodes, 1]`
        """
        # Apply dropout to the input
        x = self.dropout(x)
        # First graph attention layer
        x = self.layer1(x, adj_mat)
        # Activation function
        x = self.activation(x)
        # Dropout
        x = self.dropout(x)
        # Output layer (without activation) for logits
        return self.output(x, adj_mat)


def accuracy(output: torch.Tensor, labels: torch.Tensor):
    """
    A simple function to calculate the accuracy
    """
    return output.argmax(dim=-1).eq(labels).sum().item() / len(labels)
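

# The function below is a minimal sketch, not part of the original experiment:
# it runs a forward pass of `GAT` on random data to show the expected tensor
# shapes, including the extra head dimension added to the adjacency matrix
# (just as the training loop below does with `unsqueeze(-1)`). `_toy_forward`
# is a hypothetical helper added only for illustration.
def _toy_forward():
    n_nodes, in_features, n_classes = 4, 16, 7
    # A small model; these hyper-parameters are arbitrary
    # (`n_hidden` must be divisible by `n_heads` since the first layer concatenates the heads)
    model = GAT(in_features, n_hidden=8, n_classes=n_classes, n_heads=2, dropout=0.6)
    # Switch to evaluation mode so dropout is disabled
    model.eval()
    # Random feature vectors of shape `[n_nodes, in_features]`
    x = torch.rand(n_nodes, in_features)
    # Self-loop-only adjacency matrix with a head dimension, `[n_nodes, n_nodes, 1]`
    adj_mat = torch.eye(n_nodes, dtype=torch.bool).unsqueeze(-1)
    # Returns logits of shape `[n_nodes, n_classes]`
    return model(x, adj_mat)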


class Configs(BaseConfigs):
    """
    ## Configurations
    """

    # Model
    model: GAT
    # Number of nodes to train on
    training_samples: int = 500
    # Number of features per node in the input
    in_features: int
    # Number of features in the first graph attention layer
    n_hidden: int = 64
    # Number of heads
    n_heads: int = 8
    # Number of classes for classification
    n_classes: int
    # Dropout probability
    dropout: float = 0.6
    # Whether to include the citation network
    include_edges: bool = True
    # Dataset
    dataset: CoraDataset
    # Number of training iterations
    epochs: int = 1_000
    # Loss function
    loss_func = nn.CrossEntropyLoss()
    # Device to train on
    #
    # This creates configs for the device, so that
    # we can change the device by passing a config value
    device: torch.device = DeviceConfigs()
    # Optimizer
    optimizer: torch.optim.Adam

    def run(self):
        """
        ### Training loop

        We do full-batch training since the dataset is small.
        If we were to sample and train, we would have to sample a set of
        nodes for each training step, along with the edges that span
        those selected nodes.
        """
        # Move the feature vectors to the device
        features = self.dataset.features.to(self.device)
        # Move the labels to the device
        labels = self.dataset.labels.to(self.device)
        # Move the adjacency matrix to the device
        edges_adj = self.dataset.adj_mat.to(self.device)
        # Add an empty third dimension for the heads
        edges_adj = edges_adj.unsqueeze(-1)

        # Random indexes
        idx_rand = torch.randperm(len(labels))
        # Nodes for training
        idx_train = idx_rand[:self.training_samples]
        # Nodes for validation
        idx_valid = idx_rand[self.training_samples:]

        # Training loop
        for epoch in monit.loop(self.epochs):
            # Set the model to training mode
            self.model.train()
            # Make all the gradients zero
            self.optimizer.zero_grad()
            # Evaluate the model
            output = self.model(features, edges_adj)
            # Get the loss for training nodes
            loss = self.loss_func(output[idx_train], labels[idx_train])
            # Calculate gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Log the loss
            tracker.add('loss.train', loss)
            # Log the accuracy
            tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))

            # Set the model to evaluation mode for validation
            self.model.eval()

            # No need to compute gradients
            with torch.no_grad():
                # Evaluate the model again
                output = self.model(features, edges_adj)
                # Calculate the loss for validation nodes
                loss = self.loss_func(output[idx_valid], labels[idx_valid])
                # Log the loss
                tracker.add('loss.valid', loss)
                # Log the accuracy
                tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))

            # Save logs
            tracker.save()


@option(Configs.dataset)
def cora_dataset(c: Configs):
    """
    Create the Cora dataset
    """
    return CoraDataset(c.include_edges)


# Get the number of classes
calculate(Configs.n_classes, lambda c: len(c.dataset.classes))
# Number of features in the input
calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])


@option(Configs.model)
def gat_model(c: Configs):
    """
    Create the GAT model
    """
    return GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout).to(c.device)


@option(Configs.optimizer)
def _optimizer(c: Configs):
    """
    Create a configurable optimizer
    """
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf


def main():
    # Create configurations
    conf = Configs()
    # Create an experiment
    experiment.create(name='gat')
    # Calculate configurations
    experiment.configs(conf, {
        # Adam optimizer
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 5e-3,
        'optimizer.weight_decay': 5e-4,
    })

    # Start and watch the experiment
    with experiment.start():
        # Run the training
        conf.run()


#
if __name__ == '__main__':
    main()
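

# An illustrative variation, not part of the original experiment: the same
# training run with the citation network ignored, by overriding the
# `include_edges` config through the same `experiment.configs` override
# mechanism used in `main` above. `main_no_edges` is a hypothetical entry
# point added only as a usage example.
def main_no_edges():
    # Create configurations
    conf = Configs()
    # Create an experiment with a different name
    experiment.create(name='gat_no_edges')
    # Override `include_edges` along with the same optimizer settings as `main`
    experiment.configs(conf, {
        # Drop the citation edges to measure how much accuracy they contribute
        'include_edges': False,
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 5e-3,
        'optimizer.weight_decay': 5e-4,
    })

    # Start and watch the experiment
    with experiment.start():
        # Run the training
        conf.run()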