import sys
sys.path.insert(0, '..')  # make python_environment_check (one directory up) importable
from python_environment_check import check_packages
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import NNConv, global_add_pool
import numpy as np
from torch.utils.data import random_split
import matplotlib.pyplot as plt
d = {
    'torch': '1.8.0',
    'torch_geometric': '2.0.2',
    'numpy': '1.21.2',
    'matplotlib': '3.4.3',
}
check_packages(d)
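# Load the QM9 dataset of roughly 130k small molecules; the first call
# downloads and preprocesses the data into the current directory ('.').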
dset = QM9('.')
len(dset)  # number of molecules in the dataset
data = dset[0]  # first molecule as a torch_geometric.data.Data object
data
data.z  # atomic numbers of the atoms in this molecule
data.new_attribute = torch.tensor([1, 2, 3])  # Data objects accept arbitrary new attributes
data
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data.to(device)  # moves every tensor attribute of the Data object, including new_attribute
data.new_attribute.is_cuda
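# ExampleNet: a graph neural network with two NNConv (edge-conditioned
# convolution) layers, where a small MLP maps each edge's features to the
# weights of that edge's convolution, followed by global add pooling and a
# fully connected head that predicts a single target value per molecule.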
class ExampleNet(torch.nn.Module):
    def __init__(self, num_node_features, num_edge_features):
        super().__init__()
        # MLPs that map edge features to the convolution weights used by NNConv
        conv1_net = nn.Sequential(
            nn.Linear(num_edge_features, 32),
            nn.ReLU(),
            nn.Linear(32, num_node_features*32))
        conv2_net = nn.Sequential(
            nn.Linear(num_edge_features, 32),
            nn.ReLU(),
            nn.Linear(32, 32*16))
        self.conv1 = NNConv(num_node_features, 32, conv1_net)
        self.conv2 = NNConv(32, 16, conv2_net)
        self.fc_1 = nn.Linear(16, 32)
        self.out = nn.Linear(32, 1)

    def forward(self, data):
        batch, x, edge_index, edge_attr = (
            data.batch, data.x, data.edge_index, data.edge_attr)
        # two edge-conditioned graph convolutions
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x = F.relu(self.conv2(x, edge_index, edge_attr))
        # sum node embeddings into one graph-level embedding per molecule
        x = global_add_pool(x, batch)
        x = F.relu(self.fc_1(x))
        output = self.out(x)
        return output
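# Split the 130,831 molecules into training, validation, and test sets and
# wrap each split in a PyTorch Geometric DataLoader that batches graphs.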
train_set, valid_set, test_set = random_split(dset, [110000, 10831, 10000])
trainloader = DataLoader(train_set, batch_size=32, shuffle=True)
validloader = DataLoader(valid_set, batch_size=32)  # no shuffling needed for evaluation
testloader = DataLoader(test_set, batch_size=32)
qm9_node_feats, qm9_edge_feats = 11, 4  # dimensionality of QM9 node and edge features
epochs = 4
net = ExampleNet(qm9_node_feats, qm9_edge_feats)
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
target_idx = 1  # index of the isotropic polarizability target in data.y
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)
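# Train for a few epochs, tracking the average MSE per graph on the training
# set and on the validation set after each epoch.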
for total_epochs in range(epochs):
    epoch_loss = 0
    total_graphs = 0
    net.train()
    for batch in trainloader:
        batch.to(device)
        optimizer.zero_grad()
        output = net(batch)
        loss = F.mse_loss(output, batch.y[:, target_idx].unsqueeze(1))
        loss.backward()
        epoch_loss += loss.item()
        total_graphs += batch.num_graphs
        optimizer.step()
    train_avg_loss = epoch_loss / total_graphs
    val_loss = 0
    total_graphs = 0
    net.eval()
    with torch.no_grad():  # no gradients needed for validation
        for batch in validloader:
            batch.to(device)
            output = net(batch)
            loss = F.mse_loss(output, batch.y[:, target_idx].unsqueeze(1))
            val_loss += loss.item()
            total_graphs += batch.num_graphs
    val_avg_loss = val_loss / total_graphs
    print(f"Epoch: {total_epochs} | epoch avg. loss: {train_avg_loss:.2f} | validation avg. loss: {val_avg_loss:.2f}")
net.eval()
predictions = []
real = []
with torch.no_grad():  # inference only
    for batch in testloader:
        output = net(batch.to(device))
        predictions.append(output.detach().cpu().numpy())
        real.append(batch.y[:, target_idx].detach().cpu().numpy())
predictions = np.concatenate(predictions)
real = np.concatenate(real)
plt.scatter(real[:500], predictions[:500])
plt.ylabel('Predicted isotropic polarizability')
plt.xlabel('Isotropic polarizability')
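plt.show()  # display predicted vs. true isotropic polarizability (omit with inline notebook plotting)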