Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
rasbt
GitHub Repository: rasbt/machine-learning-book
Path: blob/main/ch18/ch18_part2.py
1245 views
1
# coding: utf-8
2
3
4
import sys
5
from python_environment_check import check_packages
6
import torch
7
import torch.nn.functional as F
8
import torch.nn as nn
9
from torch_geometric.datasets import QM9
10
from torch_geometric.loader import DataLoader
11
from torch_geometric.nn import NNConv, global_add_pool
12
import numpy as np
13
from torch.utils.data import random_split
14
import matplotlib.pyplot as plt
15
16
# # Machine Learning with PyTorch and Scikit-Learn
17
# # -- Code Examples
18
19
# ## Package version checks
20
21
# Add the parent folder to the path so that check_packages can be loaded from the python_environment_check.py script:
22
23
24
25
# Put the parent directory first on the module search path.
# NOTE(review): the `from python_environment_check import check_packages`
# statement at the top of this file executes before this line, so the path
# tweak only helps if that import is moved below it -- confirm intent.
sys.path.insert(0, "..")

# Package names mapped to the version strings handed to check_packages().
d = {
    "torch": "1.8.0",
    "torch_geometric": "2.0.2",
    "numpy": "1.21.2",
    "matplotlib": "3.4.3",
}

# Check recommended package versions.
check_packages(d)
42
43
44
# # Chapter 18 - Graph Neural Networks for Capturing Dependencies in Graph Structured Data (Part 2/2)
45
46
# - [Implementing a GNN using the PyTorch Geometric library](#Implementing-a-GNN-using-the-PyTorch-Geometric-library)
47
# - [Other GNN layers and recent developments](#Other-GNN-layers-and-recent-developments)
48
# - [Spectral graph convolutions](#Spectral-graph-convolutions)
49
# - [Pooling](#Pooling)
50
# - [Normalization](#Normalization)
51
# - [Pointers to advanced graph neural network literature](#Pointers-to-advanced-graph-neural-network-literature)
52
# - [Summary](#Summary)
53
54
55
56
57
58
# ## Implementing a GNN using the PyTorch Geometric library
59
60
61
62
63
64
65
66
67
68
69
# Instantiate the QM9 dataset with '.' as its argument (presumably the
# root/download directory -- confirm against the torch_geometric docs).
dset = QM9('.')
len(dset)  # notebook-style inspection; the value is not used in a script

# Take the first entry of the dataset and inspect it.
data = dset[0]
data

# Inspect the `z` attribute of the entry.
data.z

# A new attribute can be attached to the entry by plain assignment.
data.new_attribute = torch.tensor([1, 2, 3])
data

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Move the entry to the chosen device and check where the new attribute lives.
data.to(device)
data.new_attribute.is_cuda
95
96
97
98
99
class ExampleNet(torch.nn.Module):
    """Regressor built from two NNConv layers followed by a small MLP head.

    Each NNConv layer gets its own edge network: a two-layer MLP applied to
    the edge attributes whose output size is ``in_channels * out_channels``
    of the convolution it feeds. After the convolutions the node embeddings
    are pooled with ``global_add_pool`` using the batch assignment vector and
    mapped to a single output value per pooled row.

    Args:
        num_node_features: Size of the per-node input feature vector.
        num_edge_features: Size of the per-edge input feature vector.
    """

    def __init__(self, num_node_features, num_edge_features):
        super().__init__()

        # Edge network for the first convolution (num_node_features -> 32).
        edge_net_1 = nn.Sequential(
            nn.Linear(num_edge_features, 32),
            nn.ReLU(),
            nn.Linear(32, num_node_features * 32),
        )
        # Edge network for the second convolution (32 -> 16).
        edge_net_2 = nn.Sequential(
            nn.Linear(num_edge_features, 32),
            nn.ReLU(),
            nn.Linear(32, 32 * 16),
        )

        self.conv1 = NNConv(num_node_features, 32, edge_net_1)
        self.conv2 = NNConv(32, 16, edge_net_2)

        # Prediction head applied to the pooled embedding.
        self.fc_1 = nn.Linear(16, 32)
        self.out = nn.Linear(32, 1)

    def forward(self, data):
        """Run the network on ``data``.

        ``data`` must expose ``x``, ``edge_index``, ``edge_attr`` and
        ``batch`` attributes. Returns the output of the final linear layer.
        """
        edge_index = data.edge_index
        edge_attr = data.edge_attr

        hidden = torch.relu(self.conv1(data.x, edge_index, edge_attr))
        hidden = torch.relu(self.conv2(hidden, edge_index, edge_attr))

        # Collapse node embeddings according to the batch assignment vector.
        pooled = global_add_pool(hidden, data.batch)

        return self.out(torch.relu(self.fc_1(pooled)))
121
122
123
124
125
126
127
# Randomly partition the dataset into train / validation / test subsets
# holding 110000, 10831 and 10000 entries respectively.
train_set, valid_set, test_set = random_split(
    dset,
    [110000, 10831, 10000],
)

# One shuffling loader per subset, all using mini-batches of 32.
trainloader, validloader, testloader = (
    DataLoader(subset, batch_size=32, shuffle=True)
    for subset in (train_set, valid_set, test_set)
)
132
133
134
135
136
# Node and edge feature sizes passed to ExampleNet for the QM9 data.
qm9_node_feats, qm9_edge_feats = 11, 4

# Number of passes over the training set (previously assigned twice).
epochs = 4
target_idx = 1  # index position of the polarizability label

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Move the model to its device *before* handing its parameters to the
# optimizer, so the optimizer is built against the parameters that will
# actually be trained.
net = ExampleNet(qm9_node_feats, qm9_edge_feats)
net.to(device)

optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
149
150
151
152
153
# Train for `epochs` passes, reporting the per-graph mean squared error on
# the training and validation sets after every pass.
for total_epochs in range(epochs):

    # ---- training pass ----
    epoch_loss = 0
    total_graphs = 0
    net.train()
    for batch in trainloader:
        # Rebind the result instead of relying on an in-place move.
        batch = batch.to(device)
        optimizer.zero_grad()
        output = net(batch)
        loss = F.mse_loss(output, batch.y[:, target_idx].unsqueeze(1))
        loss.backward()
        optimizer.step()
        # mse_loss (default reduction) is already a mean over the batch, so
        # weight it by the batch size before accumulating; dividing the sum
        # by total_graphs below then yields a true per-graph average.
        epoch_loss += loss.item() * batch.num_graphs
        total_graphs += batch.num_graphs

    train_avg_loss = epoch_loss / total_graphs

    # ---- validation pass ----
    val_loss = 0
    total_graphs = 0
    net.eval()
    # No gradients are needed for evaluation; skip building autograd graphs.
    with torch.no_grad():
        for batch in validloader:
            batch = batch.to(device)
            output = net(batch)
            loss = F.mse_loss(output, batch.y[:, target_idx].unsqueeze(1))
            val_loss += loss.item() * batch.num_graphs
            total_graphs += batch.num_graphs

    val_avg_loss = val_loss / total_graphs

    print(f"Epochs: {total_epochs} | epoch avg. loss: {train_avg_loss:.2f} | validation avg. loss: {val_avg_loss:.2f}")
182
183
184
185
186
# Collect model outputs and the matching target column for the whole test set.
net.eval()
predictions = []
real = []

# Inference only: disable autograd so no graph is built for each batch.
with torch.no_grad():
    for batch in testloader:
        # Rebind so batch.y below is guaranteed to be the moved copy.
        batch = batch.to(device)
        output = net(batch)
        predictions.append(output.cpu().numpy())
        real.append(batch.y[:, target_idx].cpu().numpy())

# Stitch the per-batch arrays together along the first axis.
predictions = np.concatenate(predictions)
real = np.concatenate(real)
198
199
200
201
202
203
204
# Scatter the first 500 targets against the first 500 model outputs.
n_points = 500
plt.scatter(real[:n_points], predictions[:n_points])
plt.xlabel('Isotropic polarizability')
plt.ylabel('Predicted isotropic polarizability')
# Optional: write the figure to disk.
# plt.savefig('figures/18_12.png', dpi=300)
208
209
210
# ## Other GNN layers and recent developments
211
212
# ### Spectral graph convolutions
213
214
# ### Pooling
215
216
217
218
219
220
# ### Normalization
221
222
# ### Pointers to advanced graph neural network literature
223
224
# ## Summary
225
226
# ---
227
#
228
# Readers may ignore the next cell.
229
230
231
232
233
234