Path: blob/master/labml_nn/graphs/gatv2/experiment.py
"""1---2title: Train a Graph Attention Network v2 (GATv2) on Cora dataset3summary: >4This trains is a Graph Attention Network v2 (GATv2) on Cora dataset5---67# Train a Graph Attention Network v2 (GATv2) on Cora dataset8"""910import torch11from torch import nn1213from labml import experiment14from labml.configs import option15from labml_nn.graphs.gat.experiment import Configs as GATConfigs16from labml_nn.graphs.gatv2 import GraphAttentionV2Layer171819class GATv2(nn.Module):20"""21## Graph Attention Network v2 (GATv2)2223This graph attention network has two [graph attention layers](index.html).24"""2526def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float,27share_weights: bool = True):28"""29* `in_features` is the number of features per node30* `n_hidden` is the number of features in the first graph attention layer31* `n_classes` is the number of classes32* `n_heads` is the number of heads in the graph attention layers33* `dropout` is the dropout probability34* `share_weights` if set to True, the same matrix will be applied to the source and the target node of every edge35"""36super().__init__()3738# First graph attention layer where we concatenate the heads39self.layer1 = GraphAttentionV2Layer(in_features, n_hidden, n_heads,40is_concat=True, dropout=dropout, share_weights=share_weights)41# Activation function after first graph attention layer42self.activation = nn.ELU()43# Final graph attention layer where we average the heads44self.output = GraphAttentionV2Layer(n_hidden, n_classes, 1,45is_concat=False, dropout=dropout, share_weights=share_weights)46# Dropout47self.dropout = nn.Dropout(dropout)4849def forward(self, x: torch.Tensor, adj_mat: torch.Tensor):50"""51* `x` is the features vectors of shape `[n_nodes, in_features]`52* `adj_mat` is the adjacency matrix of the form53`[n_nodes, n_nodes, n_heads]` or `[n_nodes, n_nodes, 1]`54"""55# Apply dropout to the input56x = self.dropout(x)57# First graph attention layer58x = self.layer1(x, adj_mat)59# Activation function60x = self.activation(x)61# Dropout62x = self.dropout(x)63# Output layer (without activation) for logits64return self.output(x, adj_mat)656667class Configs(GATConfigs):68"""69## Configurations7071Since the experiment is same as [GAT experiment](../gat/experiment.html) but with72[GATv2 model](index.html) we extend the same configs and change the model.73"""7475# Whether to share weights for source and target nodes of edges76share_weights: bool = False77# Set the model78model: GATv2 = 'gat_v2_model'798081@option(Configs.model)82def gat_v2_model(c: Configs):83"""84Create GATv2 model85"""86return GATv2(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout, c.share_weights).to(c.device)878889def main():90# Create configurations91conf = Configs()92# Create an experiment93experiment.create(name='gatv2')94# Calculate configurations.95experiment.configs(conf, {96# Adam optimizer97'optimizer.optimizer': 'Adam',98'optimizer.learning_rate': 5e-3,99'optimizer.weight_decay': 5e-4,100101'dropout': 0.7,102})103104# Start and watch the experiment105with experiment.start():106# Run the training107conf.run()108109110#111if __name__ == '__main__':112main()113114115