Path: blob/main/intermediate_source/neural_tangent_kernels.py
3704 views
# -*- coding: utf-8 -*-1"""2Neural Tangent Kernels3======================45The neural tangent kernel (NTK) is a kernel that describes6`how a neural network evolves during training <https://en.wikipedia.org/wiki/Neural_tangent_kernel>`_.7There has been a lot of research around it `in recent years <https://arxiv.org/abs/1806.07572>`_.8This tutorial, inspired by the implementation of `NTKs in JAX <https://github.com/google/neural-tangents>`_9(see `Fast Finite Width Neural Tangent Kernel <https://arxiv.org/abs/2206.08720>`_ for details),10demonstrates how to easily compute this quantity using ``torch.func``,11composable function transforms for PyTorch.1213.. note::1415This tutorial requires PyTorch 2.6.0 or later.1617Setup18-----1920First, some setup. Let's define a simple CNN that we wish to compute the NTK of.21"""2223import torch24import torch.nn as nn25from torch.func import functional_call, vmap, vjp, jvp, jacrev2627if torch.accelerator.is_available() and torch.accelerator.device_count() > 0:28device = torch.accelerator.current_accelerator()29else:30device = torch.device("cpu")313233class CNN(nn.Module):34def __init__(self):35super(CNN, self).__init__()36self.conv1 = nn.Conv2d(3, 32, (3, 3))37self.conv2 = nn.Conv2d(32, 32, (3, 3))38self.conv3 = nn.Conv2d(32, 32, (3, 3))39self.fc = nn.Linear(21632, 10)4041def forward(self, x):42x = self.conv1(x)43x = x.relu()44x = self.conv2(x)45x = x.relu()46x = self.conv3(x)47x = x.flatten(1)48x = self.fc(x)49return x5051######################################################################52# And let's generate some random data5354x_train = torch.randn(20, 3, 32, 32, device=device)55x_test = torch.randn(5, 3, 32, 32, device=device)5657######################################################################58# Create a function version of the model59# --------------------------------------60#61# ``torch.func`` transforms operate on functions. 
# In particular, to compute the NTK,
# we will need a function that accepts the parameters of the model and a single
# input (as opposed to a batch of inputs!) and returns a single output.
#
# We'll use ``torch.func.functional_call``, which allows us to call an ``nn.Module``
# using different parameters/buffers, to help accomplish the first step.
#
# Keep in mind that the model was originally written to accept a batch of input
# data points. In our CNN example, there are no inter-batch operations. That
# is, each data point in the batch is independent of other data points. With
# this assumption in mind, we can easily generate a function that evaluates the
# model on a single data point:


net = CNN().to(device)

# Detaching the parameters because we won't be calling Tensor.backward().
params = {k: v.detach() for k, v in net.named_parameters()}

def fnet_single(params, x):
    """Evaluate ``net`` on one unbatched sample ``x`` using ``params``.

    A leading batch dimension of size 1 is added before the call and removed
    from the result, so both the input and the output are unbatched.
    """
    return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0)

######################################################################
# Compute the NTK: method 1 (Jacobian contraction)
# ------------------------------------------------
# We're ready to compute the empirical NTK. The empirical NTK for two data
# points :math:`x_1` and :math:`x_2` is defined as the matrix product between the Jacobian
# of the model evaluated at :math:`x_1` and the Jacobian of the model evaluated at
# :math:`x_2`:
#
# .. math::
#
#    J_{net}(x_1) J_{net}^T(x_2)
#
# In the batched case where :math:`x_1` is a batch of data points and :math:`x_2` is a
# batch of data points, then we want the matrix product between the Jacobians
# of all combinations of data points from :math:`x_1` and :math:`x_2`.
#
# The first method consists of doing just that - computing the two Jacobians,
# and contracting them.
# Here's how to compute the NTK in the batched case:

def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):
    """Compute the full empirical NTK between batches ``x1`` and ``x2``.

    ``fnet_single(params, x)`` must evaluate the model on one unbatched
    sample. The per-parameter Jacobians are flattened to ``(batch, out, -1)``
    and contracted over the parameter axis; the contributions of all
    parameters are then summed. The result is indexed ``[N, M, a, b]``.
    """
    # Compute J(x1)
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac1 = jac1.values()
    jac1 = [j.flatten(2) for j in jac1]

    # Compute J(x2)
    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
    jac2 = jac2.values()
    jac2 = [j.flatten(2) for j in jac2]

    # Compute J(x1) @ J(x2).T
    result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)])
    result = result.sum(0)
    return result

result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)
print(result.shape)

######################################################################
# In some cases, you may only want the diagonal or the trace of this quantity,
# especially if you know beforehand that the network architecture results in an
# NTK where the non-diagonal elements can be approximated by zero. It's easy to
# adjust the above function to do that:

def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):
    """Compute the empirical NTK between ``x1`` and ``x2``, or a reduction of it.

    ``compute`` selects what is returned:

    - ``'full'``: the whole kernel, indexed ``[N, M, a, b]``.
    - ``'trace'``: the trace over the two output axes, indexed ``[N, M]``.
    - ``'diagonal'``: the diagonal over the two output axes, indexed ``[N, M, a]``.

    Raises ``ValueError`` for any other value of ``compute``.
    """
    einsum_exprs = {
        'full': 'Naf,Mbf->NMab',
        'trace': 'Naf,Maf->NM',
        'diagonal': 'Naf,Maf->NMa',
    }
    # Validate before the expensive Jacobian computation. An ``assert`` would
    # be stripped under ``python -O`` and let ``None`` reach ``torch.einsum``.
    if compute not in einsum_exprs:
        raise ValueError(
            f"compute must be one of {sorted(einsum_exprs)}, got {compute!r}"
        )
    einsum_expr = einsum_exprs[compute]

    # Compute J(x1)
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac1 = jac1.values()
    jac1 = [j.flatten(2) for j in jac1]

    # Compute J(x2)
    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
    jac2 = jac2.values()
    jac2 = [j.flatten(2) for j in jac2]

    # Compute J(x1) @ J(x2).T
    result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])
    result = result.sum(0)
    return result

result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')
print(result.shape)

######################################################################
# The
# asymptotic time complexity of this method is :math:`N O [FP]` (time to
# compute the Jacobians) + :math:`N^2 O^2 P` (time to contract the Jacobians),
# where :math:`N` is the batch size of :math:`x_1` and :math:`x_2`, :math:`O`
# is the model's output size, :math:`P` is the total number of parameters, and
# :math:`[FP]` is the cost of a single forward pass through the model. See
# section 3.2 in
# `Fast Finite Width Neural Tangent Kernel <https://arxiv.org/abs/2206.08720>`_
# for details.
#
# Compute the NTK: method 2 (NTK-vector products)
# -----------------------------------------------
#
# The next method we will discuss is a way to compute the NTK using NTK-vector
# products.
#
# This method reformulates NTK as a stack of NTK-vector products applied to
# columns of an identity matrix :math:`I_O` of size :math:`O\times O`
# (where :math:`O` is the output size of the model):
#
# .. math::
#
#    J_{net}(x_1) J_{net}^T(x_2) = J_{net}(x_1) J_{net}^T(x_2) I_{O} = \left[J_{net}(x_1) \left[J_{net}^T(x_2) e_o\right]\right]_{o=1}^{O},
#
# where :math:`e_o\in \mathbb{R}^O` are column vectors of the identity matrix
# :math:`I_O`.
#
# - Let :math:`\textrm{vjp}_o = J_{net}^T(x_2) e_o`. We can use
#   a vector-Jacobian product to compute this.
# - Now, consider :math:`J_{net}(x_1) \textrm{vjp}_o`.
#   This is a
#   Jacobian-vector product!
# - Finally, we can run the above computation in parallel over all
#   columns :math:`e_o` of :math:`I_O` using ``vmap``.
#
# This suggests that we can use a combination of reverse-mode AD (to compute
# the vector-Jacobian product) and forward-mode AD (to compute the
# Jacobian-vector product) to compute the NTK.
#
# Let's code that up:

def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):
    """Compute the empirical NTK between ``x1`` and ``x2`` via NTK-vector products.

    ``func(params, x)`` must evaluate the model on one unbatched sample.
    ``compute`` selects what is returned:

    - ``'full'``: the whole kernel, indexed ``[N, M, K, K]``.
    - ``'trace'``: the trace over the last two axes, indexed ``[N, M]``.
    - ``'diagonal'``: the diagonal over the last two axes, indexed ``[N, M, K]``.

    Raises ``ValueError`` for any other value of ``compute``.
    """
    # Validate up front: otherwise an unknown mode would run the whole
    # computation and then silently return ``None``.
    if compute not in ('full', 'trace', 'diagonal'):
        raise ValueError(
            f"compute must be 'full', 'trace' or 'diagonal', got {compute!r}"
        )

    def get_ntk(x1, x2):
        def func_x1(params):
            return func(params, x1)

        def func_x2(params):
            return func(params, x2)

        output, vjp_fn = vjp(func_x1, params)

        def get_ntk_slice(vec):
            # This computes ``vec @ J(x1)``.
            # `vec` is some unit vector (a single slice of the Identity matrix).
            # ``vjp_fn`` returns a tuple with one cotangent per primal, which is
            # exactly the tangents tuple ``jvp`` expects for ``(params,)``.
            vjps = vjp_fn(vec)
            # This computes ``J(x2) @ vjps``, i.e. one row of ``J(x1) @ J(x2).T``
            _, jvps = jvp(func_x2, (params,), vjps)
            return jvps

        # Here's our identity matrix
        basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1)
        return vmap(get_ntk_slice)(basis)

    # ``get_ntk(x1, x2)`` computes the NTK for a single data point x1, x2
    # Since the x1, x2 inputs to ``empirical_ntk_ntk_vps`` are batched,
    # we actually wish to compute the NTK between every pair of data points
    # between {x1} and {x2}. That's what the ``vmaps`` here do.
    result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)

    if compute == 'full':
        return result
    if compute == 'trace':
        return torch.einsum('NMKK->NM', result)
    # Only 'diagonal' remains, thanks to the validation above.
    return torch.einsum('NMKK->NMK', result)

# Disable TensorFloat-32 for convolutions on Ampere+ GPUs to sacrifice performance in favor of accuracy
with torch.backends.cudnn.flags(allow_tf32=False):
    result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)
    result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)

assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)

######################################################################
# Our code for ``empirical_ntk_ntk_vps`` looks like a direct translation from
# the math above! This showcases the power of function transforms: good luck
# trying to write an efficient version of the above by only using
# ``torch.autograd.grad``.
#
# The asymptotic time complexity of this method is :math:`N^2 O [FP]`, where
# :math:`N` is the batch size of :math:`x_1` and :math:`x_2`, :math:`O` is the
# model's output size, and :math:`[FP]` is the cost of a single forward pass
# through the model. Hence this method performs more forward passes through the
# network than method 1, Jacobian contraction (:math:`N^2 O` instead of
# :math:`N O`), but avoids the contraction cost altogether (no :math:`N^2 O^2 P`
# term, where :math:`P` is the total number of model's parameters). Therefore,
# this method is preferable when :math:`O P` is large relative to :math:`[FP]`,
# such as fully-connected (not convolutional) models with many outputs :math:`O`.
# Memory-wise, both methods should be comparable. See section 3.3 in
# `Fast Finite Width Neural Tangent Kernel <https://arxiv.org/abs/2206.08720>`_
# for details.