# -*- coding: utf-8 -*-

"""
DebugMode: Recording Dispatched Operations and Numerical Debugging
====================================================================

**Authors:** Pian Pawakapan, Shangdi Yu

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * How to capture dispatched ops for eager and ``torch.compile`` runs
       * How to use tensor hashes and stack traces in DebugMode to pinpoint numerical divergence

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch 2.10 or later

"""

######################################################################
# Overview
# --------
#
# ``DebugMode`` (:class:`torch.utils._debug_mode.DebugMode`) is a
# ``TorchDispatchMode`` that intercepts PyTorch runtime calls and emits a
# hierarchical log of operations. It is particularly useful when you need to
# understand *what* actually runs, both in eager mode and under ``torch.compile``,
# or when you need to pinpoint numerical divergence between two runs.
#
# Key capabilities:
#
# * **Runtime logging** – Records dispatched operations and TorchInductor-compiled
#   Triton kernels.
# * **Tensor hashing** – Attaches deterministic hashes to inputs/outputs, enabling
#   you to diff runs and locate numerical divergences.
# * **Dispatch hooks** – Allows registration of custom hooks that annotate calls
#   with extra metadata.
#
# .. note::
#
#    This recipe describes a prototype feature. Prototype features are typically
#    at an early stage for feedback and testing and are subject to change.
#

######################################################################
# Quick start
# -----------
#
# The snippet below captures a small eager workload and prints the debug string:

import torch
from torch.utils._debug_mode import DebugMode


def run_once():
    x = torch.randn(8, 8)
    y = torch.randn(8, 8)
    return torch.mm(torch.relu(x), y)


with DebugMode() as debug_mode:
    out = run_once()

print("DebugMode output:")
print(debug_mode.debug_string())


######################################################################
# Getting more metadata
# ---------------------
#
# For most investigations, you'll want to enable stack traces, tensor IDs, and tensor hashing.
# These features provide metadata to correlate operations back to model code.
#
# ``DebugMode.log_tensor_hashes`` decorates the log with hashes for every call.
# The ``hash_tensor`` hash function uses ``torch.hash_tensor``, which returns 0 for tensors whose
# elements are all the same. The ``norm`` hash function uses ``norm`` with ``p=1``.
# With both functions, and especially ``norm``, tensors that are numerically close produce hashes
# that are close, which makes the hashes fairly interpretable. The default ``hash_fn`` is ``norm``.

with (
    DebugMode(
        # record_stack_trace is only supported in eager mode in PyTorch 2.10
        record_stack_trace=True,
        record_ids=True,
    ) as debug_mode,
    DebugMode.log_tensor_hashes(
        hash_fn=["norm"],  # this is the default
        hash_inputs=True,
    ),
):
    result = run_once()

print("DebugMode output with more metadata:")
print(debug_mode.debug_string(show_stack_trace=True))

######################################################################
# Each line follows ``op(args) -> outputs``. When ``record_ids`` is enabled,
# tensors are suffixed with ``$<id>`` and DTensors are labeled ``dt``.

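######################################################################
# To build intuition for the default ``norm`` hash, here is a small sketch that is
# not part of the original recipe: it computes the ``p=1`` norm directly, under the
# assumption (stated above) that this is what the ``norm`` hash reduces to. Numerically
# close tensors produce close values, so a small divergence between two runs shows up
# as a small difference between their hashes. The names ``base`` and ``perturbed`` are
# illustrative only.

base = torch.randn(8, 8)
perturbed = base + 1e-6  # simulate a tiny numerical divergence
print("p=1 norm of base:     ", base.norm(p=1).item())
print("p=1 norm of perturbed:", perturbed.norm(p=1).item())
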
######################################################################
# Log Triton kernels
# ------------------
#
# Though Triton kernels are not dispatched through the PyTorch dispatcher, DebugMode has
# custom logic that logs their inputs and outputs.
#
# Inductor-generated Triton kernels show up with a ``[triton]`` prefix.
# Pre/post hash annotations report buffer hashes around each kernel call, which
# is helpful when isolating incorrect kernels.


def f(x):
    return torch.mm(torch.relu(x), x.T)


# This example compiles with TorchInductor and requires a CUDA device.
x = torch.randn(3, 3, device="cuda")

with (
    DebugMode(record_output=True) as debug_mode,
    DebugMode.log_tensor_hashes(
        hash_inputs=True,
    ),
):
    a = torch.compile(f)(x)

print("Triton in DebugMode logs:")
print(debug_mode.debug_string())

######################################################################
# Numerical debugging with tensor hashes
# --------------------------------------
#
# If you see numerical divergence between two modes, you can use DebugMode to find
# where the divergence originates.
# In the example below, all tensor hashes are the same for eager mode and compiled mode.
# If any hash were different, that is where the numerical divergence would be coming from.


def run_model(model, data, *, compile_with=None):
    if compile_with is not None:
        model = torch.compile(model, backend=compile_with)
    with DebugMode(record_output=True) as dm, DebugMode.log_tensor_hashes(
        hash_inputs=True,
    ):
        dm_out = model(*data)
    return dm, dm_out


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x).mm(x.T)


inputs = (torch.randn(4, 4),)
dm_eager, _ = run_model(Toy(), inputs)
dm_compiled, _ = run_model(Toy(), inputs, compile_with="aot_eager")

print("Eager mode:")
print(dm_eager.debug_string())
print("Compiled aot_eager mode:")
print(dm_compiled.debug_string())

###############################################################################################
# Now let's look at an example where the tensor hashes are different.
# We intentionally register a wrong decomposition that decomposes ``aten.cos`` to ``torch.sin``,
# which causes numerical divergence.


from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.backends.debugging import get_nop_func


def wrong_decomp(x):
    return torch.sin(x)


decomp_table = {torch.ops.aten.cos.default: wrong_decomp}

backend = aot_autograd(
    fw_compiler=get_nop_func(),
    bw_compiler=get_nop_func(),
    decompositions=decomp_table,
)


def f(x):
    y = x.relu()
    z = torch.cos(x)
    return y + z


x = torch.randn(3, 3)
with DebugMode(record_output=True) as dm_eager, DebugMode.log_tensor_hashes(
    hash_inputs=True,
):
    f(x)

with DebugMode(record_output=True) as dm_compiled, DebugMode.log_tensor_hashes(
    hash_inputs=True,
):
    torch.compile(f, backend=backend)(x)

print("Eager:")
print(dm_eager.debug_string(show_stack_trace=True))
print()
print("Compiled with wrong decomposition:")
print(dm_compiled.debug_string())

###############################################################################################
# In the eager log, we have ``aten::cos``, but in the compiled log, we have ``aten::sin``.
# Moreover, the output hashes differ between eager and compiled mode.
# Diffing the two logs, as sketched below, shows that the first numerical divergence
# appears at the ``aten::cos`` call.

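######################################################################
# As a rough sketch that is not part of the original recipe, you can also diff the two
# debug strings programmatically with Python's standard ``difflib`` module to surface the
# lines where the logs disagree. The variable names below are illustrative only.

import difflib

eager_lines = dm_eager.debug_string().splitlines()
compiled_lines = dm_compiled.debug_string().splitlines()

# Print only the differing lines; the first "-"/"+" pair points at the
# place where the two logs start to diverge.
for line in difflib.unified_diff(eager_lines, compiled_lines, lineterm="", n=0):
    print(line)
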
######################################################################
# Custom dispatch hooks
# ---------------------
#
# Hooks allow you to annotate each call with custom metadata, such as GPU memory usage.
# The ``log_hook`` callable returns a mapping that is rendered inline with the debug string.

MB = 1024 * 1024.0


def memory_hook(func, types, args, kwargs, result):
    mem = peak = 0.0
    if torch.cuda.is_available():
        mem = torch.cuda.memory_allocated() / MB
        peak = torch.cuda.max_memory_allocated() / MB
        torch.cuda.reset_peak_memory_stats()
    return {"mem": f"{mem:.3f} MB", "peak": f"{peak:.3f} MB"}


with (
    DebugMode() as dm,
    DebugMode.dispatch_hooks(log_hook=memory_hook),
):
    run_once()

print("DebugMode output with memory usage:")
print(dm.debug_string())

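######################################################################
# The same mechanism works for any per-call metadata you can compute from the hook's
# arguments. As a small sketch that is not part of the original recipe, the hypothetical
# ``dtype_hook`` below reuses the ``(func, types, args, kwargs, result)`` signature shown
# above to tag each call with the dtype of its tensor output; it assumes the hook may
# return an empty mapping when there is nothing to report.


def dtype_hook(func, types, args, kwargs, result):
    # Annotate calls that return a single tensor with that tensor's dtype.
    if isinstance(result, torch.Tensor):
        return {"out_dtype": str(result.dtype)}
    return {}


with (
    DebugMode() as dm_dtype,
    DebugMode.dispatch_hooks(log_hook=dtype_hook),
):
    run_once()

print("DebugMode output with output dtypes:")
print(dm_dtype.debug_string())
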
######################################################################
# Module boundaries
# -----------------
#
# ``record_nn_module=True`` inserts ``[nn.Mod]`` markers that show which
# module executed each set of operations. As of PyTorch 2.10, it only works in eager mode,
# but support for compiled modes is under development.


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(4, 4)
        self.l2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.l2(self.l1(x))


class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.abc = Foo()
        self.xyz = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.xyz(self.abc(x))


mod = Bar()
inp = torch.randn(4, 4)
with DebugMode(record_nn_module=True, record_output=False) as debug_mode:
    _ = mod(inp)

print("DebugMode output with module boundaries:")
print(debug_mode.debug_string())

######################################################################
# Conclusion
# ----------
#
# In this tutorial, we saw how DebugMode gives you a lightweight, runtime-only
# view of what PyTorch actually executed, whether you are running eager code or
# compiled graphs. By layering tensor hashing, Triton kernel logging, and custom
# dispatch hooks, you can quickly track down numerical differences. This is
# especially helpful when debugging bit-wise equivalence between runs.