Path: blob/main/recipes_source/debug_mode_tutorial.py
# -*- coding: utf-8 -*-

"""
DebugMode: Recording Dispatched Operations and Numerical Debugging
====================================================================

**Authors:** Pian Pawakapan, Shangdi Yu

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * How to capture dispatched ops for eager and ``torch.compile`` runs
       * How to use tensor hashes and stack traces in DebugMode to pinpoint numerical divergence

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch 2.10 or later

"""

######################################################################
# Overview
# --------
#
# ``DebugMode`` (:class:`torch.utils._debug_mode.DebugMode`) is a
# ``TorchDispatchMode`` that intercepts PyTorch runtime calls and emits a
# hierarchical log of operations. It is particularly useful when you need to
# understand *what* actually runs, both in eager mode and under ``torch.compile``,
# or when you need to pinpoint numerical divergence between two runs.
#
# Key capabilities:
#
# * **Runtime logging** – Records dispatched operations and TorchInductor-compiled
#   Triton kernels.
# * **Tensor hashing** – Attaches deterministic hashes to inputs/outputs to enable
#   diffing runs to locate numerical divergences.
# * **Dispatch hooks** – Allows registration of custom hooks to annotate calls.
#
# .. note::
#
#    This recipe describes a prototype feature. Prototype features are typically
#    at an early stage for feedback and testing and are subject to change.
#

######################################################################
# Quick start
# -----------
#
# The snippet below captures a small eager workload and prints the debug string:

import torch
from torch.utils._debug_mode import DebugMode


def run_once():
    x = torch.randn(8, 8)
    y = torch.randn(8, 8)
    return torch.mm(torch.relu(x), y)


with DebugMode() as debug_mode:
    out = run_once()

print("DebugMode output:")
print(debug_mode.debug_string())

######################################################################
# Getting more metadata
# ----------------------
#
# For most investigations, you'll want to enable stack traces, tensor IDs, and tensor hashing.
# These features provide metadata to correlate operations back to model code.
#
# ``DebugMode.log_tensor_hashes`` decorates the log with hashes for every call.
# The ``hash_tensor`` hash function uses ``torch.hash_tensor``, which returns 0 for tensors whose
# elements are all the same. The ``norm`` hash function uses the L1 norm (``p=1``).
# With both functions, and especially ``norm``, tensors that are numerically close produce hashes
# that are close, which makes the hashes fairly interpretable. The default ``hash_fn`` is ``norm``.

with (
    DebugMode(
        # record_stack_trace is only supported for eager mode in PyTorch 2.10
        record_stack_trace=True,
        record_ids=True,
    ) as debug_mode,
    DebugMode.log_tensor_hashes(
        hash_fn=["norm"],  # this is the default
        hash_inputs=True,
    ),
):
    result = run_once()

print("DebugMode output with more metadata:")
print(debug_mode.debug_string(show_stack_trace=True))

######################################################################
# Each line follows ``op(args) -> outputs``. When ``record_ids`` is enabled,
# tensors are suffixed with ``$<id>`` and DTensors are labeled ``dt``.
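#
# For a real model the log can get long. Since ``debug_string()`` returns an ordinary string,
# one simple workflow is to write it to disk and compare runs with any text diff tool. This is
# plain Python rather than a DebugMode feature, and the file name below is arbitrary.

# Dump the annotated log so it can be diffed against a later run.
with open("debug_mode_run.log", "w") as log_file:
    log_file.write(debug_mode.debug_string(show_stack_trace=True))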

######################################################################
# Log Triton kernels
# -------------------
#
# Though Triton kernels are not dispatched, DebugMode has custom logic that logs their inputs and outputs.
#
# Inductor-generated Triton kernels show up with a ``[triton]`` prefix.
# Pre/post hash annotations report buffer hashes around each kernel call, which
# is helpful when isolating incorrect kernels. Note that this example requires a CUDA device.


def f(x):
    return torch.mm(torch.relu(x), x.T)


x = torch.randn(3, 3, device="cuda")

with (
    DebugMode(record_output=True) as debug_mode,
    DebugMode.log_tensor_hashes(
        hash_inputs=True,
    ),
):
    a = torch.compile(f)(x)

print("Triton kernels in DebugMode logs:")
print(debug_mode.debug_string())

######################################################################
# Numerical debugging with tensor hashes
# ---------------------------------------
#
# If two runs diverge numerically, you can use DebugMode to find where the divergence originates.
# In the example below, every tensor hash matches between the eager run and the compiled run.
# If any hash differed, that call would be where the numerical divergence comes from.


def run_model(model, data, *, compile_with=None):
    if compile_with is not None:
        model = torch.compile(model, backend=compile_with)
    with DebugMode(record_output=True) as dm, DebugMode.log_tensor_hashes(
        hash_inputs=True,
    ):
        dm_out = model(*data)
    return dm, dm_out


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x).mm(x.T)


inputs = (torch.randn(4, 4),)
dm_eager, _ = run_model(Toy(), inputs)
dm_compiled, _ = run_model(Toy(), inputs, compile_with="aot_eager")

print("Eager mode:")
print(dm_eager.debug_string())
print("Compiled aot_eager mode:")
print(dm_compiled.debug_string())

###############################################################################################
# Now let's look at an example where the tensor hashes differ.
# We intentionally register a wrong decomposition that rewrites cosine as sine,
# which causes a numerical divergence.

from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.backends.debugging import get_nop_func


def wrong_decomp(x):
    return torch.sin(x)


decomp_table = {torch.ops.aten.cos.default: wrong_decomp}

backend = aot_autograd(
    fw_compiler=get_nop_func(),
    bw_compiler=get_nop_func(),
    decompositions=decomp_table,
)


def f(x):
    y = x.relu()
    z = torch.cos(x)
    return y + z


x = torch.randn(3, 3)
with DebugMode(record_output=True) as dm_eager, DebugMode.log_tensor_hashes(
    hash_inputs=True,
):
    f(x)

with DebugMode(record_output=True) as dm_compiled, DebugMode.log_tensor_hashes(
    hash_inputs=True,
):
    torch.compile(f, backend=backend)(x)

print("Eager:")
print(dm_eager.debug_string())
print()
print("Compiled with wrong decomposition:")
print(dm_compiled.debug_string())

###############################################################################################
# In the eager log we have ``aten::cos``, but in the compiled log we have ``aten::sin``.
# Moreover, the output hashes differ between eager and compiled mode.
# Diffing the two logs shows that the first numerical divergence appears at the
# ``aten::cos`` call.
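#
# For longer traces it can be easier to diff the two logs programmatically rather than by eye.
# The sketch below uses Python's standard ``difflib`` on the two debug strings captured above;
# it is plain string diffing, not a DebugMode API, so the exact output depends on the log format.

import difflib

# Show only the lines that differ between the eager and compiled logs.
# The first differing operation lines point at where the divergence starts.
diff = difflib.unified_diff(
    dm_eager.debug_string().splitlines(),
    dm_compiled.debug_string().splitlines(),
    fromfile="eager",
    tofile="compiled",
    lineterm="",
)
for line in diff:
    if line.startswith(("+", "-")) and not line.startswith(("+++", "---")):
        print(line)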

######################################################################
# Custom dispatch hooks
# ----------------------
#
# Hooks let you annotate each call with custom metadata, such as GPU memory usage. The function
# passed as ``log_hook`` returns a mapping that is rendered inline with the debug string.

MB = 1024 * 1024.0


def memory_hook(func, types, args, kwargs, result):
    mem = torch.cuda.memory_allocated() / MB if torch.cuda.is_available() else 0.0
    peak = torch.cuda.max_memory_allocated() / MB if torch.cuda.is_available() else 0.0
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    return {"mem": f"{mem:.3f} MB", "peak": f"{peak:.3f} MB"}


with (
    DebugMode() as dm,
    DebugMode.dispatch_hooks(log_hook=memory_hook),
):
    run_once()

print("DebugMode output with memory usage:")
print(dm.debug_string())

######################################################################
# Module boundaries
# ------------------
#
# ``record_nn_module=True`` inserts ``[nn.Mod]`` markers that show which
# module executed each set of operations. As of PyTorch 2.10 this only works in eager mode,
# but support for compiled modes is under development.


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(4, 4)
        self.l2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.l2(self.l1(x))


class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.abc = Foo()
        self.xyz = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.xyz(self.abc(x))


mod = Bar()
inp = torch.randn(4, 4)
with DebugMode(record_nn_module=True, record_output=False) as debug_mode:
    _ = mod(inp)

print("DebugMode output with module boundaries:")
print(debug_mode.debug_string())

######################################################################
# Conclusion
# ----------
#
# In this recipe, we saw how DebugMode gives you a lightweight, runtime-only
# view of what PyTorch actually executed, whether you are running eager code or
# compiled graphs. By layering on tensor hashing, Triton kernel logging, and custom
# dispatch hooks, you can quickly track down numerical differences between runs,
# which is especially helpful when debugging bit-wise equivalence.