GitHub Repository: rasbt/machine-learning-book
Path: blob/main/ch18/ch18_part1.py
# coding: utf-8

import sys
from python_environment_check import check_packages
import networkx as nx
import numpy as np
import torch
from torch.nn.parameter import Parameter
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# # Machine Learning with PyTorch and Scikit-Learn
# # -- Code Examples

# ## Package version checks

# Add folder to path in order to load from the check_packages.py script:

sys.path.insert(0, '..')


# Check recommended package versions:

d = {
    'torch': '1.8.0',
    'networkx': '2.6.2',
    'numpy': '1.21.2',
}

check_packages(d)

# # Chapter 18 - Graph Neural Networks for Capturing Dependencies in Graph Structured Data (Part 1/2)

# - [Introduction to graph data](#Introduction-to-graph-data)
#   - [Undirected graphs](#Undirected-graphs)
#   - [Directed graphs](#Directed-graphs)
#   - [Labeled graphs](#Labeled-graphs)
#   - [Representing molecules as graphs](#Representing-molecules-as-graphs)
# - [Understanding graph convolutions](#Understanding-graph-convolutions)
#   - [The motivation behind using graph convolutions](#The-motivation-behind-using-graph-convolutions)
#   - [Implementing a basic graph convolution](#Implementing-a-basic-graph-convolution)
# - [Implementing a GNN in PyTorch from scratch](#Implementing-a-GNN-in-PyTorch-from-scratch)
#   - [Defining the NodeNetwork model](#Defining-the-NodeNetwork-model)
#   - [Coding the NodeNetwork’s graph convolution layer](#Coding-the-NodeNetworks-graph-convolution-layer)
#   - [Adding a global pooling layer to deal with varying graph sizes](#Adding-a-global-pooling-layer-to-deal-with-varying-graph-sizes)
#   - [Preparing the DataLoader](#Preparing-the-DataLoader)
#   - [Using the NodeNetwork to make predictions](#Using-the-NodeNetwork-to-make-predictions)

# ## Introduction to graph data

# ### Undirected graphs

# ### Directed graphs

# ### Labeled graphs

# ### Representing molecules as graphs


# ## Understanding graph convolutions

# ### The motivation behind using graph convolutions

# ### Implementing a basic graph convolution

G = nx.Graph()

# Hex codes for colors if we draw the graph
blue, orange, green = "#1f77b4", "#ff7f0e", "#2ca02c"

G.add_nodes_from([(1, {"color": blue}),
                  (2, {"color": orange}),
                  (3, {"color": blue}),
                  (4, {"color": green})])

G.add_edges_from([(1, 2), (2, 3), (1, 3), (3, 4)])
A = np.asarray(nx.adjacency_matrix(G).todense())
print(A)
# adjacency matrix of the 4-node graph:
# [[0 1 1 0]
#  [1 0 1 0]
#  [1 1 0 1]
#  [0 0 1 0]]


def build_graph_color_label_representation(G, mapping_dict):
    # one-hot encode each node's color attribute
    one_hot_idxs = np.array([mapping_dict[v] for v in
                             nx.get_node_attributes(G, 'color').values()])
    one_hot_encoding = np.zeros((one_hot_idxs.size, len(mapping_dict)))
    one_hot_encoding[np.arange(one_hot_idxs.size), one_hot_idxs] = 1
    return one_hot_encoding


X = build_graph_color_label_representation(G, {green: 0, blue: 1, orange: 2})
print(X)
# one-hot node feature matrix (rows: nodes 1-4; columns: green, blue, orange):
# [[0. 1. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]
#  [1. 0. 0.]]

color_map = nx.get_node_attributes(G, 'color').values()
nx.draw(G, with_labels=True, node_color=color_map)


# A basic graph convolution: each node combines its own features (via W_1)
# with the summed features of its neighbors (via W_2).
f_in, f_out = X.shape[1], 6
W_1 = np.random.rand(f_in, f_out)
W_2 = np.random.rand(f_in, f_out)
h = np.dot(X, W_1) + np.dot(np.dot(A, X), W_2)

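# The matrix form above can be read node by node: row i of h is X[i] @ W_1
# plus the sum of X[j] @ W_2 over the neighbors j of node i. The loop below
# is a small sketch added here for illustration (it is not part of the
# original book code) to verify that equivalence.
h_loop = np.zeros_like(h)
for i in range(A.shape[0]):
    h_loop[i] = np.dot(X[i], W_1)
    for j in range(A.shape[0]):
        if A[i, j]:
            h_loop[i] += np.dot(X[j], W_2)

print(np.allclose(h, h_loop))  # expected: True
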
# ## Implementing a GNN in PyTorch from scratch

# ### Defining the NodeNetwork model

class NodeNetwork(torch.nn.Module):
    def __init__(self, input_features):
        super().__init__()
        self.conv_1 = BasicGraphConvolutionLayer(input_features, 32)
        self.conv_2 = BasicGraphConvolutionLayer(32, 32)
        self.fc_1 = torch.nn.Linear(32, 16)
        self.out_layer = torch.nn.Linear(16, 2)

    def forward(self, X, A, batch_mat):
        x = self.conv_1(X, A).clamp(0)   # .clamp(0) acts as a ReLU
        x = self.conv_2(x, A).clamp(0)
        output = global_sum_pool(x, batch_mat)
        output = self.fc_1(output)
        output = self.out_layer(output)
        return F.softmax(output, dim=1)

# ### Coding the NodeNetwork’s graph convolution layer

class BasicGraphConvolutionLayer(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.W2 = Parameter(torch.rand(
            (in_channels, out_channels), dtype=torch.float32))
        self.W1 = Parameter(torch.rand(
            (in_channels, out_channels), dtype=torch.float32))

        self.bias = Parameter(torch.zeros(
            out_channels, dtype=torch.float32))

    def forward(self, X, A):
        potential_msgs = torch.mm(X, self.W2)          # messages from each node
        propagated_msgs = torch.mm(A, potential_msgs)  # sum messages over neighbors
        root_update = torch.mm(X, self.W1)             # each node's self-update
        output = propagated_msgs + root_update + self.bias
        return output

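# A quick sanity check, added here as a sketch (not part of the code above):
# apply the layer to the small example graph, reusing the NumPy feature
# matrix X and adjacency matrix A built earlier, converted to float32
# tensors. The names basiclayer and out are introduced only for this example.
print('X.shape:', X.shape)   # (4, 3)
print('A.shape:', A.shape)   # (4, 4)

basiclayer = BasicGraphConvolutionLayer(3, 8)
out = basiclayer(
    torch.tensor(X, dtype=torch.float32),
    torch.tensor(A, dtype=torch.float32))
print('Output shape:', out.shape)  # expected: torch.Size([4, 8])
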
# ### Adding a global pooling layer to deal with varying graph sizes

def global_sum_pool(X, batch_mat):
    # Sum node embeddings into one vector per graph; each row of batch_mat
    # selects (with 1s) the nodes belonging to one graph in the batch.
    if batch_mat is None or batch_mat.dim() == 1:
        return torch.sum(X, dim=0).unsqueeze(0)
    else:
        return torch.mm(batch_mat, X)

def get_batch_tensor(graph_sizes):
    # Build a (num_graphs x total_nodes) mask whose rows mark which node
    # indices belong to each graph in the batch.
    starts = [sum(graph_sizes[:idx]) for idx in range(len(graph_sizes))]
    stops = [starts[idx] + graph_sizes[idx] for idx in range(len(graph_sizes))]
    tot_len = sum(graph_sizes)
    batch_size = len(graph_sizes)
    batch_mat = torch.zeros([batch_size, tot_len]).float()
    for idx, starts_and_stops in enumerate(zip(starts, stops)):
        start = starts_and_stops[0]
        stop = starts_and_stops[1]
        batch_mat[idx, start:stop] = 1
    return batch_mat
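
# A small illustration added here (not part of the original code): for two
# graphs with 2 and 3 nodes, get_batch_tensor builds a 2 x 5 mask whose rows
# select each graph's nodes, so multiplying it with a stacked node-feature
# matrix sums the features per graph. example_batch and node_feats are names
# introduced only for this sketch.
example_batch = get_batch_tensor([2, 3])
print(example_batch)
# expected:
# tensor([[1., 1., 0., 0., 0.],
#         [0., 0., 1., 1., 1.]])

node_feats = torch.ones(5, 4)  # 5 nodes, 4 features each
print(global_sum_pool(node_feats, example_batch))
# expected:
# tensor([[2., 2., 2., 2.],
#         [3., 3., 3., 3.]])
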

def collate_graphs(batch):
    adj_mats = [graph['A'] for graph in batch]
    sizes = [A.size(0) for A in adj_mats]
    tot_size = sum(sizes)
    # create batch matrix
    batch_mat = get_batch_tensor(sizes)
    # combine feature matrices
    feat_mats = torch.cat([graph['X'] for graph in batch], dim=0)
    # combine labels
    labels = torch.cat([graph['y'] for graph in batch], dim=0)
    # combine adjacency matrices into one block-diagonal matrix
    batch_adj = torch.zeros([tot_size, tot_size], dtype=torch.float32)
    accum = 0
    for adj in adj_mats:
        g_size = adj.shape[0]
        batch_adj[accum:accum+g_size, accum:accum+g_size] = adj
        accum = accum + g_size
    repr_and_label = {
        'A': batch_adj,
        'X': feat_mats,
        'y': labels,
        'batch': batch_mat}
    return repr_and_label

# ### Preparing the DataLoader

def get_graph_dict(G, mapping_dict):
    # build dictionary representation of graph G
    A = torch.from_numpy(np.asarray(nx.adjacency_matrix(G).todense())).float()
    # build_graph_color_label_representation() was introduced with the first example graph
    X = torch.from_numpy(build_graph_color_label_representation(G, mapping_dict)).float()
    # kludge since there is no specific task for this example
    y = torch.tensor([[1, 0]]).float()
    return {'A': A, 'X': X, 'y': y, 'batch': None}


# building 4 graphs to treat as a dataset
blue, orange, green = "#1f77b4", "#ff7f0e", "#2ca02c"
mapping_dict = {green: 0, blue: 1, orange: 2}

G1 = nx.Graph()
G1.add_nodes_from([(1, {"color": blue}),
                   (2, {"color": orange}),
                   (3, {"color": blue}),
                   (4, {"color": green})])
G1.add_edges_from([(1, 2), (2, 3), (1, 3), (3, 4)])

G2 = nx.Graph()
G2.add_nodes_from([(1, {"color": green}),
                   (2, {"color": green}),
                   (3, {"color": orange}),
                   (4, {"color": orange}),
                   (5, {"color": blue})])
G2.add_edges_from([(2, 3), (3, 4), (3, 1), (5, 1)])

G3 = nx.Graph()
G3.add_nodes_from([(1, {"color": orange}),
                   (2, {"color": orange}),
                   (3, {"color": green}),
                   (4, {"color": green}),
                   (5, {"color": blue}),
                   (6, {"color": orange})])
G3.add_edges_from([(2, 3), (3, 4), (3, 1), (5, 1), (2, 5), (6, 1)])

G4 = nx.Graph()
G4.add_nodes_from([(1, {"color": blue}), (2, {"color": blue}), (3, {"color": green})])
G4.add_edges_from([(1, 2), (2, 3)])

graph_list = [get_graph_dict(graph, mapping_dict) for graph in [G1, G2, G3, G4]]

class ExampleDataset(Dataset):
    # Simple PyTorch dataset that will use our list of graphs
    def __init__(self, graph_list):
        self.graphs = graph_list

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        mol_rep = self.graphs[idx]
        return mol_rep


dset = ExampleDataset(graph_list)
# Note how we use our custom collate function
loader = DataLoader(dset, batch_size=2, shuffle=False, collate_fn=collate_graphs)
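
# A brief sketch added for illustration (not part of the original code) of
# what one collated batch looks like: with batch_size=2 and shuffle=False,
# the first batch stacks G1 (4 nodes) and G2 (5 nodes), so the combined
# feature matrix has 9 rows, the block-diagonal adjacency matrix is 9 x 9,
# and the batch mask is 2 x 9. first_batch is a name used only here.
first_batch = next(iter(loader))
print('X:', first_batch['X'].shape)           # torch.Size([9, 3])
print('A:', first_batch['A'].shape)           # torch.Size([9, 9])
print('batch:', first_batch['batch'].shape)   # torch.Size([2, 9])
print('y:', first_batch['y'].shape)           # torch.Size([2, 2])
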

# ### Using the NodeNetwork to make predictions

torch.manual_seed(123)
node_features = 3
net = NodeNetwork(node_features)

batch_results = []
for b in loader:
    batch_results.append(net(b['X'], b['A'], b['batch']).detach())

G1_rep = dset[1]
G1_single = net(G1_rep['X'], G1_rep['A'], G1_rep['batch']).detach()

# dset[1] is the second graph in graph_list, and batch_results[0][1] is the
# prediction for the second graph of the first batch, so both refer to the
# same graph; the single-graph and batched predictions should therefore agree
# up to floating-point tolerance.
G1_batch = batch_results[0][1]
torch.all(torch.isclose(G1_single, G1_batch))


# ---
#
# Readers may ignore the next cell.
