# -*- coding: utf-8 -*-
"""
Neural Tangent Kernels
======================

The neural tangent kernel (NTK) is a kernel that describes
`how a neural network evolves during training <https://en.wikipedia.org/wiki/Neural_tangent_kernel>`_.
There has been a lot of research around it `in recent years <https://arxiv.org/abs/1806.07572>`_.
This tutorial, inspired by the implementation of `NTKs in JAX <https://github.com/google/neural-tangents>`_
(see `Fast Finite Width Neural Tangent Kernel <https://arxiv.org/abs/2206.08720>`_ for details),
demonstrates how to easily compute this quantity using ``torch.func``,
composable function transforms for PyTorch.

.. note::

   This tutorial requires PyTorch 2.6.0 or later.

Setup
-----

First, some setup. Let's define a simple CNN that we wish to compute the NTK of.
"""

import torch
import torch.nn as nn
from torch.func import functional_call, vmap, vjp, jvp, jacrev

if torch.accelerator.is_available() and torch.accelerator.device_count() > 0:
    device = torch.accelerator.current_accelerator()
else:
    device = torch.device("cpu")


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, (3, 3))
        self.conv2 = nn.Conv2d(32, 32, (3, 3))
        self.conv3 = nn.Conv2d(32, 32, (3, 3))
        self.fc = nn.Linear(21632, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = x.relu()
        x = self.conv2(x)
        x = x.relu()
        x = self.conv3(x)
        x = x.flatten(1)
        x = self.fc(x)
        return x

######################################################################
# And let's generate some random data

x_train = torch.randn(20, 3, 32, 32, device=device)
x_test = torch.randn(5, 3, 32, 32, device=device)

######################################################################
# Create a function version of the model
# --------------------------------------
#
# ``torch.func`` transforms operate on functions. In particular, to compute the NTK,
# we will need a function that accepts the parameters of the model and a single
# input (as opposed to a batch of inputs!) and returns a single output.
#
# We'll use ``torch.func.functional_call``, which allows us to call an ``nn.Module``
# using different parameters/buffers, to help accomplish the first step.
#
# Keep in mind that the model was originally written to accept a batch of input
# data points. In our CNN example, there are no inter-batch operations. That
# is, each data point in the batch is independent of other data points. With
# this assumption in mind, we can easily generate a function that evaluates the
# model on a single data point:


net = CNN().to(device)

# Detaching the parameters because we won't be calling Tensor.backward().
params = {k: v.detach() for k, v in net.named_parameters()}

def fnet_single(params, x):
    return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0)

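
# As a quick sanity check (an editorial addition, not part of the original
# tutorial), the single-example function should agree with calling the module
# directly on a batch of size one:
assert torch.allclose(fnet_single(params, x_train[0]), net(x_train[:1]).squeeze(0))
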
######################################################################
# Compute the NTK: method 1 (Jacobian contraction)
# ------------------------------------------------
# We're ready to compute the empirical NTK. The empirical NTK for two data
# points :math:`x_1` and :math:`x_2` is defined as the matrix product between the Jacobian
# of the model evaluated at :math:`x_1` and the Jacobian of the model evaluated at
# :math:`x_2`:
#
# .. math::
#
#    J_{net}(x_1) J_{net}^T(x_2)
#
# In the batched case where :math:`x_1` is a batch of data points and :math:`x_2` is a
# batch of data points, we want the matrix product between the Jacobians
# of all combinations of data points from :math:`x_1` and :math:`x_2`.
#
# The first method consists of doing just that: computing the two Jacobians
# and contracting them. Here's how to compute the NTK in the batched case:

def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):
    # Compute J(x1)
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac1 = jac1.values()
    jac1 = [j.flatten(2) for j in jac1]

    # Compute J(x2)
    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
    jac2 = jac2.values()
    jac2 = [j.flatten(2) for j in jac2]

    # Compute J(x1) @ J(x2).T
    # ``N``/``M`` index the points in x1/x2, ``a``/``b`` index the model
    # outputs, and ``f`` indexes the flattened parameters of each layer.
    result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)])
    result = result.sum(0)
    return result

result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)
print(result.shape)

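
# An editorial sanity check (not in the original tutorial): with 20 training
# points, 5 test points, and 10 model outputs, the full NTK has shape
# ``[20, 5, 10, 10]``.
assert result.shape == (x_train.shape[0], x_test.shape[0], net.fc.out_features, net.fc.out_features)
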
######################################################################
# In some cases, you may only want the diagonal or the trace of this quantity,
# especially if you know beforehand that the network architecture results in an
# NTK where the non-diagonal elements can be approximated by zero. It's easy to
# adjust the above function to do that:

def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):
    # Compute J(x1)
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac1 = jac1.values()
    jac1 = [j.flatten(2) for j in jac1]

    # Compute J(x2)
    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
    jac2 = jac2.values()
    jac2 = [j.flatten(2) for j in jac2]

    # Compute J(x1) @ J(x2).T
    einsum_expr = None
    if compute == 'full':
        einsum_expr = 'Naf,Mbf->NMab'
    elif compute == 'trace':
        einsum_expr = 'Naf,Maf->NM'
    elif compute == 'diagonal':
        einsum_expr = 'Naf,Maf->NMa'
    else:
        assert False

    result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])
    result = result.sum(0)
    return result

result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')
print(result.shape)

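
# Another editorial check (not in the original tutorial): with ``compute='trace'``
# the output dimensions are contracted away, leaving one scalar per pair of
# points, i.e. shape ``[20, 5]``.
assert result.shape == (x_train.shape[0], x_test.shape[0])
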
######################################################################
# The asymptotic time complexity of this method is :math:`N O [FP]` (time to
# compute the Jacobians) + :math:`N^2 O^2 P` (time to contract the Jacobians),
# where :math:`N` is the batch size of :math:`x_1` and :math:`x_2`, :math:`O`
# is the model's output size, :math:`P` is the total number of parameters, and
# :math:`[FP]` is the cost of a single forward pass through the model. See
# section 3.2 in
# `Fast Finite Width Neural Tangent Kernel <https://arxiv.org/abs/2206.08720>`_
# for details.
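#
# As a rough illustration of these symbols (an editorial addition, not part of
# the original tutorial), here are their concrete values for the example above:

num_params = sum(p.numel() for p in params.values())
print(f"N (train/test batch sizes): {x_train.shape[0]}/{x_test.shape[0]}")
print(f"O (model output size): {net.fc.out_features}")
print(f"P (total parameter count): {num_params}")

######################################################################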
#
# Compute the NTK: method 2 (NTK-vector products)
# -----------------------------------------------
#
# The next method we will discuss is a way to compute the NTK using NTK-vector
# products.
#
# This method reformulates the NTK as a stack of NTK-vector products applied to
# columns of an identity matrix :math:`I_O` of size :math:`O\times O`
# (where :math:`O` is the output size of the model):
#
# .. math::
#
#    J_{net}(x_1) J_{net}^T(x_2) = J_{net}(x_1) J_{net}^T(x_2) I_{O} = \left[J_{net}(x_1) \left[J_{net}^T(x_2) e_o\right]\right]_{o=1}^{O},
#
# where :math:`e_o\in \mathbb{R}^O` are column vectors of the identity matrix
# :math:`I_O`.
#
# - Let :math:`\textrm{vjp}_o = J_{net}^T(x_2) e_o`. We can use
#   a vector-Jacobian product to compute this.
# - Now, consider :math:`J_{net}(x_1) \textrm{vjp}_o`. This is a
#   Jacobian-vector product!
# - Finally, we can run the above computation in parallel over all
#   columns :math:`e_o` of :math:`I_O` using ``vmap``.
#
# This suggests that we can use a combination of reverse-mode AD (to compute
# the vector-Jacobian product) and forward-mode AD (to compute the
# Jacobian-vector product) to compute the NTK.
#
# Let's code that up:

def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):
    def get_ntk(x1, x2):
        def func_x1(params):
            return func(params, x1)

        def func_x2(params):
            return func(params, x2)

        output, vjp_fn = vjp(func_x1, params)

        def get_ntk_slice(vec):
            # ``vec`` is a unit vector (a single column of the identity matrix).
            # This computes ``J(x1).T @ vec``, a vector-Jacobian product.
            vjps = vjp_fn(vec)
            # This computes ``J(x2) @ vjps``, a Jacobian-vector product. Stacked
            # over all unit vectors below, these slices assemble ``J(x1) @ J(x2).T``.
            _, jvps = jvp(func_x2, (params,), vjps)
            return jvps

        # Here's our identity matrix
        basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1)
        return vmap(get_ntk_slice)(basis)

    # ``get_ntk(x1, x2)`` computes the NTK for a single pair of data points x1, x2.
    # Since the x1, x2 inputs to ``empirical_ntk_ntk_vps`` are batched,
    # we actually wish to compute the NTK between every pair of data points
    # from {x1} and {x2}. That's what the two ``vmap`` calls here do.
    result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)

    if compute == 'full':
        return result
    if compute == 'trace':
        return torch.einsum('NMKK->NM', result)
    if compute == 'diagonal':
        return torch.einsum('NMKK->NMK', result)

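
# An editorial usage example (not in the original tutorial): the ``compute``
# argument works the same way as in the Jacobian-contraction version.
print(empirical_ntk_ntk_vps(fnet_single, params, x_train, x_test, 'trace').shape)
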
# Disable TensorFloat-32 for convolutions on Ampere+ GPUs, trading speed for
# the accuracy needed for the two methods to agree within the tolerance below.
with torch.backends.cudnn.flags(allow_tf32=False):
    result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)
    result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)

assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)

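
# As a rough, hardware-dependent comparison (an editorial sketch, not part of
# the original tutorial), we can time the two methods with
# ``torch.utils.benchmark``. Which method wins depends on the model, the batch
# sizes, and the output size, per the complexity discussion below.
from torch.utils.benchmark import Timer

for name, fn in [("jacobian contraction", empirical_ntk_jacobian_contraction),
                 ("NTK-vector products", empirical_ntk_ntk_vps)]:
    timer = Timer(
        stmt="fn(fnet_single, params, x_test, x_train)",
        globals={"fn": fn, "fnet_single": fnet_single, "params": params,
                 "x_test": x_test, "x_train": x_train},
    )
    print(name, ":", timer.timeit(3))
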
######################################################################
# Our code for ``empirical_ntk_ntk_vps`` looks like a direct translation from
# the math above! This showcases the power of function transforms: good luck
# trying to write an efficient version of the above by only using
# ``torch.autograd.grad``.
#
# The asymptotic time complexity of this method is :math:`N^2 O [FP]`, where
# :math:`N` is the batch size of :math:`x_1` and :math:`x_2`, :math:`O` is the
# model's output size, and :math:`[FP]` is the cost of a single forward pass
# through the model. Hence this method performs more forward passes through the
# network than method 1, Jacobian contraction (:math:`N^2 O` instead of
# :math:`N O`), but avoids the contraction cost altogether (no :math:`N^2 O^2 P`
# term, where :math:`P` is the total number of model parameters). Therefore,
# this method is preferable when :math:`O P` is large relative to :math:`[FP]`,
# as is the case for fully-connected (not convolutional) models with many
# outputs :math:`O`. Memory-wise, both methods should be comparable. See
# section 3.3 in
# `Fast Finite Width Neural Tangent Kernel <https://arxiv.org/abs/2206.08720>`_
# for details.