GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/graphs/gat/experiment.py
"""
---
title: Train a Graph Attention Network (GAT) on the Cora dataset
summary: >
  This trains a Graph Attention Network (GAT) on the Cora dataset.
---

# Train a Graph Attention Network (GAT) on the Cora dataset
"""

from typing import Dict

import numpy as np
import torch
from torch import nn

from labml import lab, monit, tracker, experiment
from labml.configs import BaseConfigs, option, calculate
from labml.utils import download
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.graphs.gat import GraphAttentionLayer
from labml_nn.optimizers.configs import OptimizerConfigs


class CoraDataset:
    """
    ## [Cora Dataset](https://linqs.soe.ucsc.edu/data)

    The Cora dataset is a dataset of research papers.
    For each paper we are given a binary feature vector that indicates the presence of words.
    Each paper is classified into one of 7 classes.
    The dataset also has the citation network.

    The papers are the nodes of the graph and the edges are the citations.

    The task is to classify the nodes into the 7 classes, using the feature vectors and the
    citation network as input.
    """
    # Labels for each node
    labels: torch.Tensor
    # Set of class names and a unique integer index for each of them
    classes: Dict[str, int]
    # Adjacency matrix with the edge information.
    # `adj_mat[i][j]` is `True` if there is an edge from `i` to `j`.
    adj_mat: torch.Tensor
    # Feature vectors for all nodes
    features: torch.Tensor

    @staticmethod
    def _download():
        """
        Download the dataset
        """
        if not (lab.get_data_path() / 'cora').exists():
            download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',
                                   lab.get_data_path() / 'cora.tgz')
            download.extract_tar(lab.get_data_path() / 'cora.tgz', lab.get_data_path())

    def __init__(self, include_edges: bool = True):
        """
        Load the dataset
        """

        # Whether to include edges.
        # This is to test how much accuracy is lost if we ignore the citation network.
        self.include_edges = include_edges

        # Download the dataset
        self._download()

        # Read the paper ids, feature vectors, and labels
        with monit.section('Read content file'):
            content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))
        # Load the citations; it's a list of pairs of paper ids.
        with monit.section('Read citations file'):
            citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)

        # Get the feature vectors
        features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))
        # Normalize the feature vectors
        self.features = features / features.sum(dim=1, keepdim=True)

        # Get the class names and assign a unique integer to each of them
        self.classes = {s: i for i, s in enumerate(set(content[:, -1]))}
        # Get the labels as those integers
        self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)

        # Get the paper ids
        paper_ids = np.array(content[:, 0], dtype=np.int32)
        # Map of paper id to index
        ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}

        # Empty adjacency matrix - an identity matrix
        self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)

        # Mark the citations in the adjacency matrix
        if self.include_edges:
            for e in citations:
                # The pair of paper indexes
                e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]
                # We build a symmetric graph, where if paper $i$ referenced
                # paper $j$ we place an edge from $i$ to $j$ as well as an edge
                # from $j$ to $i$.
                self.adj_mat[e1][e2] = True
                self.adj_mat[e2][e1] = True
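
        # Note that initializing with an identity matrix keeps a self-loop on every
        # node, so each node can attend to itself even when `include_edges` is off.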


class GAT(nn.Module):
    """
    ## Graph Attention Network (GAT)

    This graph attention network has two [graph attention layers](index.html).
    """

    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):
        """
        * `in_features` is the number of features per node
        * `n_hidden` is the number of features in the first graph attention layer
        * `n_classes` is the number of classes
        * `n_heads` is the number of heads in the graph attention layers
        * `dropout` is the dropout probability
        """
        super().__init__()

        # First graph attention layer where we concatenate the heads
        self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)
        # Activation function after the first graph attention layer
        self.activation = nn.ELU()
        # Final graph attention layer where we average the heads
        self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)
        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):
        """
        * `x` is the feature vectors of shape `[n_nodes, in_features]`
        * `adj_mat` is the adjacency matrix of the form
          `[n_nodes, n_nodes, n_heads]` or `[n_nodes, n_nodes, 1]`
        """
        # Apply dropout to the input
        x = self.dropout(x)
        # First graph attention layer
        x = self.layer1(x, adj_mat)
        # Activation function
        x = self.activation(x)
        # Dropout
        x = self.dropout(x)
        # Output layer (without activation) for logits
        return self.output(x, adj_mat)
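
# A minimal usage sketch (hypothetical sizes; Cora itself has 2,708 nodes and
# 1,433 binary features):
#
#     model = GAT(in_features=1433, n_hidden=64, n_classes=7, n_heads=8, dropout=0.6)
#     x = torch.rand(2708, 1433)
#     adj = torch.eye(2708, dtype=torch.bool).unsqueeze(-1)  # `[n_nodes, n_nodes, 1]`
#     logits = model(x, adj)                                 # `[n_nodes, n_classes]`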


def accuracy(output: torch.Tensor, labels: torch.Tensor):
    """
    A simple function to calculate the accuracy
    """
    return output.argmax(dim=-1).eq(labels).sum().item() / len(labels)
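
# For example, `accuracy(torch.tensor([[0.1, 0.9], [0.8, 0.2]]), torch.tensor([1, 1]))`
# compares the predicted classes `[1, 0]` against the labels and returns `0.5`.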


class Configs(BaseConfigs):
    """
    ## Configurations
    """

    # Model
    model: GAT
    # Number of nodes to train on
    training_samples: int = 500
    # Number of features per node in the input
    in_features: int
    # Number of features in the first graph attention layer
    n_hidden: int = 64
    # Number of heads
    n_heads: int = 8
    # Number of classes for classification
    n_classes: int
    # Dropout probability
    dropout: float = 0.6
    # Whether to include the citation network
    include_edges: bool = True
    # Dataset
    dataset: CoraDataset
    # Number of training iterations
    epochs: int = 1_000
    # Loss function
    loss_func = nn.CrossEntropyLoss()
    # Device to train on
    #
    # This creates configs for the device, so that
    # we can change the device by passing a config value
    device: torch.device = DeviceConfigs()
    # Optimizer
    optimizer: torch.optim.Adam

    def run(self):
        """
        ### Training loop

        We do full-batch training since the dataset is small.
        If we were to sample and train, we would have to sample a set of
        nodes for each training step, along with the edges that span
        those selected nodes.
        """
        # Move the feature vectors to the device
        features = self.dataset.features.to(self.device)
        # Move the labels to the device
        labels = self.dataset.labels.to(self.device)
        # Move the adjacency matrix to the device
        edges_adj = self.dataset.adj_mat.to(self.device)
        # Add an empty third dimension for the heads
        edges_adj = edges_adj.unsqueeze(-1)
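        # The attention layers accept either `[n_nodes, n_nodes, n_heads]` or
        # `[n_nodes, n_nodes, 1]`, so this single mask is shared across all heads.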

        # Random permutation of the node indexes
        idx_rand = torch.randperm(len(labels))
        # Nodes for training
        idx_train = idx_rand[:self.training_samples]
        # Nodes for validation
        idx_valid = idx_rand[self.training_samples:]
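        # Cora has 2,708 papers, so with `training_samples = 500` the remaining
        # 2,208 nodes are used for validation.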

        # Training loop
        for epoch in monit.loop(self.epochs):
            # Set the model to training mode
            self.model.train()
            # Zero all the gradients
            self.optimizer.zero_grad()
            # Evaluate the model
            output = self.model(features, edges_adj)
            # Get the loss for the training nodes
            loss = self.loss_func(output[idx_train], labels[idx_train])
            # Calculate gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Log the loss
            tracker.add('loss.train', loss)
            # Log the accuracy
            tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))

            # Set the model to evaluation mode for validation
            self.model.eval()

            # No need to compute gradients
            with torch.no_grad():
                # Evaluate the model again
                output = self.model(features, edges_adj)
                # Calculate the loss for the validation nodes
                loss = self.loss_func(output[idx_valid], labels[idx_valid])
                # Log the loss
                tracker.add('loss.valid', loss)
                # Log the accuracy
                tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))

            # Save the logs
            tracker.save()


@option(Configs.dataset)
def cora_dataset(c: Configs):
    """
    Create Cora dataset
    """
    return CoraDataset(c.include_edges)


# Get the number of classes
calculate(Configs.n_classes, lambda c: len(c.dataset.classes))
# Number of features in the input
calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])
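
# These derive `n_classes` and `in_features` from the loaded dataset,
# so they don't need to be set manually.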


@option(Configs.model)
def gat_model(c: Configs):
    """
    Create GAT model
    """
    return GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout).to(c.device)


@option(Configs.optimizer)
def _optimizer(c: Configs):
    """
    Create configurable optimizer
    """
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf
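
# The concrete optimizer (e.g. Adam) and its hyperparameters are picked at run
# time from the `optimizer.*` values passed to `experiment.configs` in `main`.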


def main():
    # Create configurations
    conf = Configs()
    # Create an experiment
    experiment.create(name='gat')
    # Calculate configurations.
    experiment.configs(conf, {
        # Adam optimizer
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 5e-3,
        'optimizer.weight_decay': 5e-4,
    })

    # Start and watch the experiment
    with experiment.start():
        # Run the training
        conf.run()


#
if __name__ == '__main__':
    main()
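
# To try this, run the module directly, e.g. (assuming the repository root is
# on `PYTHONPATH`): `python labml_nn/graphs/gat/experiment.py`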