Commit 33dfded

eellison authored and pytorchmergebot committed
CUDAGraph Trees - Warn on dealloc (pytorch#97171)
Differential Revision: [D44228370](https://our.internmc.facebook.com/intern/diff/D44228370)

Pull Request resolved: pytorch#97171
Approved by: https://github.com/ezyang, https://github.com/jansel
1 parent 24e280d commit 33dfded
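
The warning added here fires when CUDAGraph Trees frees a tensor output that user code may still be holding. A minimal sketch of the pattern, mirroring the test added in this commit (assumes a CUDA build with cudagraph trees enabled; the config flags are the ones touched in torch/_inductor/config.py below):

import torch
import torch._inductor.config as inductor_config

# Assumed setup: enable cudagraphs with the tree-based memory pooling.
inductor_config.triton.cudagraphs = True
inductor_config.triton.cudagraph_trees = True

@torch.compile()
def foo(x):
    return x * x * x

inp = torch.rand([4], device="cuda")
out = foo(inp)  # `out` lives inside the cudagraph-owned memory pool

# Re-invoking the compiled function can force the tree manager to deallocate
# `out`'s storage; after this commit that emits a warning pointing at the
# `x * x * x` stack trace instead of silently letting later reads of `out`
# return garbage.
foo(inp)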

File tree

5 files changed: +97 -27 lines changed

test/inductor/test_cudagraph_trees.py (+17)

@@ -5,6 +5,7 @@
 import importlib
 import sys
 import unittest
+import warnings
 
 import torch
 
@@ -100,6 +101,7 @@ def setUp(self):
         config.triton.cudagraphs = True
         config.triton.cudagraph_trees = True
         self.device_idx = torch.rand([0], device="cuda").device.index
+        warnings.filterwarnings("ignore")
 
     def tearDown(self):
         super().tearDown()
@@ -109,6 +111,7 @@ def tearDown(self):
         config.triton.cudagraph_trees = self.tapes_enabled
         self.assertIsNone(self.get_manager())
         self.assertEqual(all_live_block_count(), 0)
+        warnings.resetwarnings()
 
     def get_manager(self, device_index=None):
         return torch._inductor.cudagraph_trees.get_container(
@@ -534,6 +537,20 @@ def foo(args):
         test()
         self.assertTrue(self.get_manager(device_index=1) is None)
 
+    def test_warnings_on_dealloc(self):
+        @torch.compile()
+        def foo(x):
+            return x * x * x
+
+        inp = torch.rand([4], device="cuda")
+        out = foo(inp)
+        warnings.resetwarnings()
+        with warnings.catch_warnings(record=True) as w:
+            foo(inp)
+
+        self.assertTrue(len(w) == 1)
+        self.assertTrue("x * x * x" in str(w[0]))
+
     def test_forward_generation(self):
         def foo(x):
             return x * x * x

torch/_functorch/aot_autograd.py (+2 -1)

@@ -2768,7 +2768,8 @@ def aot_function(
             larger Aten ops into simpler or core Aten ops.
         inference_compiler (Optional[Callable]): A Python function that accepts an
             Fx graph with Aten ops and input args, and returns a Callable that
-            semantically is equivalent to the input Fx graph. Default: None
+            semantically is equivalent to the input Fx graph. inference_compiler is invoked
+            if no autograd is needed. Default: None
             (when None, it defaults to the :attr:`fw_compiler`)
     Returns:
         Returns a ``Callable`` that retains the eager behavior of the original
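
For context, a hedged sketch of how the documented parameter might be used (assumes `aot_function` accepts `inference_compiler` as a keyword argument, as the docstring above implies; `noop_compiler` is a placeholder name):

from torch._functorch.aot_autograd import aot_function

def noop_compiler(fx_g, example_inputs):
    # A compiler only needs to return a callable semantically equivalent to the
    # FX graph; the GraphModule itself qualifies.
    return fx_g

def fn(x):
    return x.sin().cos()

# When no autograd is needed, inference_compiler is invoked instead of
# fw_compiler; leaving it as None falls back to fw_compiler.
compiled = aot_function(fn, fw_compiler=noop_compiler, inference_compiler=noop_compiler)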

torch/_inductor/compile_fx.py (+11 -2)

@@ -185,9 +185,15 @@ def compile_fx_inner(
     if aot_mode:
         return compiled_fn
 
-    output = list(gm.graph.nodes)[-1]
-    assert len(output.args) == 1
     if cudagraphs:
+        # output args are tuple of first argument
+        output = list(gm.graph.nodes)[-1]
+        assert len(output.args) == 1
+        stack_traces = [
+            (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
+            for arg in output.args[0]
+        ]
+
         complex_memory_overlap_inputs = any(
             complex_memory_overlap(t) for t in example_inputs
         )
@@ -204,6 +210,7 @@ def compile_fx_inner(
                 example_inputs,
                 static_input_idxs=range(num_fixed),
                 device_index=next(iter(graph.device_idxs)),
+                stack_traces=stack_traces,
                 is_backward=is_backward,
                 is_inference=is_inference,
             )
@@ -279,6 +286,7 @@ def cudagraphify(
     static_input_idxs=(),
     *,
     device_index: int,
+    stack_traces: List[Optional[str]],
     is_backward: bool,
     is_inference: bool,
 ):
@@ -290,6 +298,7 @@ def cudagraphify(
     cudagraphify_fn = functools.partial(
         new_cudagraphify_impl,
         device_index=device_index,
+        stack_traces=stack_traces,
        is_backward=is_backward,
        is_inference=is_inference,
    )
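
A standalone sketch of the stack-trace harvesting above: the hypothetical `StackTraceTracer` subclass sets `record_stack_traces`, the FX tracer flag that populates `node.stack_trace`; without it the harvested entries are simply None, which is also what the inductor code tolerates.

import torch
import torch.fx as fx

class StackTraceTracer(fx.Tracer):
    # Ask FX to attach a .stack_trace string to each node it creates.
    record_stack_traces = True

def f(x):
    return (x * x * x, x + 1)

graph = StackTraceTracer().trace(f)

# Same shape as the inductor code above: the last node is the `output` node,
# and its single argument holds the tuple of graph outputs.
output = list(graph.nodes)[-1]
assert len(output.args) == 1
stack_traces = [
    (arg.stack_trace if isinstance(arg, fx.node.Node) else None)
    for arg in output.args[0]
]
print(stack_traces)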

torch/_inductor/config.py (+5 -1)

@@ -194,7 +194,11 @@ class triton:
     # Use cudagraph trees for memory pooling if `cudagraphs` is True
     cudagraph_trees = False
 
-    debug_cudagraph_trees = True
+    # assertions not on the fast path, steady state
+    fast_cudagraph_asserts = True
+
+    # assertions on the fast path
+    slow_cudagraph_asserts = False
 
     # skip warmup for cudagraph trees
     skip_cudagraph_warmup = False
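
The renamed flags can be toggled at runtime for debugging; a sketch (the extra per-replay overhead of the fast-path checks is why slow_cudagraph_asserts defaults to False):

import torch._inductor.config as inductor_config

# Consistency checks kept off the replay hot path; on by default.
inductor_config.triton.fast_cudagraph_asserts = True

# Per-invocation invariant checks on the replay fast path; costly, off by default.
inductor_config.triton.slow_cudagraph_asserts = True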

torch/_inductor/cudagraph_trees.py (+62 -23)

@@ -258,14 +258,22 @@ def cudagraphify_impl(
     device_index: int,
     is_backward: bool,
     is_inference: bool,
+    stack_traces: Optional[StackTraces] = None,
 ):
     manager = get_container(device_index).get_tree_manager()
+    assert not (is_backward and is_inference)
+    mode = (
+        CompilationMode.BACKWARD
+        if is_backward
+        else (CompilationMode.INFERENCE if is_inference else CompilationMode.FORWARD)
+    )
+
     return manager.add_function(
         model,
         inputs,
         static_input_idxs,
-        is_backward,
-        is_inference,
+        stack_traces,
+        mode,
     )
 
 
@@ -351,6 +359,8 @@ def map_to_ref(t: Optional[Tensor]) -> Optional[StorageWeakRefWrapper]:
 # For each node in the path, for each output, is the output alive
 PathLiveness = List[List[bool]]
 
+StackTraces = List[Optional[str]]
+
 
 class CUDAWarmupNode:
     """
@@ -378,6 +388,7 @@ def __init__(
         cuda_graphs_pool: Tuple[int, int],
         existing_cuda_graph: torch.cuda.Graph,
         device_index: int,
+        stack_traces: Optional[StackTraces],
     ):
         self.wrapped_function = wrapped_function
         self.parent = parent
@@ -386,6 +397,7 @@ def __init__(
         self.existing_cuda_graph = existing_cuda_graph
         self.has_run = False
         self.device_index = device_index
+        self.stack_traces = stack_traces
 
     def run(self, new_inputs):
         assert not self.has_run, "Wrapped function should never be run twice"
@@ -403,7 +415,7 @@ def run(self, new_inputs):
             ):
                 non_cudagraph_inps.add(new_inputs[i].untyped_storage().data_ptr())
 
-        if config.triton.debug_cudagraph_trees:
+        if config.triton.fast_cudagraph_asserts:
             refs = list(self.path_live_weakrefs())
             check_memory_pool(self.cuda_graphs_pool, refs)
 
@@ -425,7 +437,7 @@ def run(self, new_inputs):
             ]
         )
 
-        if config.triton.debug_cudagraph_trees:
+        if config.triton.fast_cudagraph_asserts:
             out_refs = self.path_live_weakrefs()
             new_storages = [
                 t for t in out_refs if t.data_ptr() not in non_cudagraph_inps
@@ -436,16 +448,22 @@ def run(self, new_inputs):
 
     def path_live_weakrefs(self) -> Generator[StorageWeakRefWrapper]:
         "Returns all live storages weakrefs that created by nodes in this path"
+        for stor_ref, _ in self.path_live_weakrefs_and_stacktraces():
+            yield stor_ref
+
+    def path_live_weakrefs_and_stacktraces(
+        self,
+    ) -> Generator[Tuple[StorageWeakRefWrapper, Optional[str]]]:
         nodes = []
         node = self
         while node:
             nodes.append(node)
             node = node.parent
 
         for node in reversed(nodes):
-            for output in node.outputs_weakrefs:
+            for i, output in enumerate(node.outputs_weakrefs):
                 if is_live(output):
-                    yield output
+                    yield output, (node.stack_traces[i] if node.stack_traces else None)
 
     def all_outputs_are_dead(self):
         return not list(self.path_live_weakrefs())
@@ -486,12 +504,14 @@ def __init__(
         inputs: List[Tensor],
         cuda_graphs_pool: Tuple[int, int],
         device_index: int,
+        stack_traces: Optional[StackTraces],
     ):
         assert isinstance(inputs, (list, tuple))
 
         self.wrapped_function = wrapped_function
         self.id = id
         self.device = device_index
+        self.stack_traces = stack_traces
 
         # if this is a root parent will be None. use weakref to prevent reference cycle
         self._parent = weakref.ref(parent) if parent is not None else None
@@ -510,6 +530,9 @@ def __init__(
         self.path_weakrefs: LevelList[OutputList[Optional[StorageWeakRefWrapper]]] = [
             node.outputs_weakrefs for node in self._path_from_root
         ]
+        self.path_stacktraces: LevelList[StackTraces] = [
+            node.stack_traces for node in self._path_from_root
+        ]
 
         # tensors which are outputs of previous graphs in the tree
         self.cudagraph_managed_idxs: List[int] = [
@@ -616,7 +639,7 @@ def __init__(
         self.checkpointed_caching_state: Optional[AllocatorState] = None
 
     def run(self, new_inputs):
-        if config.triton.debug_cudagraph_trees:
+        if config.triton.slow_cudagraph_asserts:
             self.debug_check_invariants_before_invocation()
 
         assert len(self.static_input_data_ptrs) == len(new_inputs)
@@ -677,7 +700,7 @@ def all_outputs_are_dead(self):
     def _record(self, model, stream, inputs):
         "Record the model"
 
-        if config.triton.debug_cudagraph_trees:
+        if config.triton.fast_cudagraph_asserts:
             # need to use parent live weakrefs because live_indices isnt set yet
             memory = (
                 [] if self.parent is None else list(self.parent.path_live_weakrefs())
@@ -720,6 +743,13 @@ def _add_first_outputs(self, outputs):
             and o.untyped_storage().data_ptr() in self.static_input_storage_ptrs
         )
 
+        if self.stack_traces is None:
+            self.stack_traces = [None for _ in range(len(outputs))]
+        else:
+            assert len(self.stack_traces) == len(
+                outputs
+            ), "Wrong number of stack traces passed in"
+
         self._add_replayed_outputs(outputs)
         self.recorded_liveness_after_graph = self._get_liveness(self.path_weakrefs)
 
@@ -734,7 +764,7 @@ def _add_first_outputs(self, outputs):
             self.live_indices_after_graph.append((depth, output_index))
 
         self.debug_check_invariants_after_invocation()
-        if config.triton.debug_cudagraph_trees:
+        if config.triton.fast_cudagraph_asserts:
             check_memory_pool(self.cuda_graphs_pool, list(self.path_live_weakrefs()))
 
     def _add_replayed_outputs(self, outputs):
@@ -816,7 +846,7 @@ def _get_liveness(
     def debug_assert_invariants(
         self, expected_liveness: List[List[bool]], newly_dead: List[PathOutputIndex]
     ):
-        if not config.triton.debug_cudagraph_trees:
+        if not config.triton.slow_cudagraph_asserts:
             return
 
         for i, node in enumerate(self._path_from_root):
@@ -1066,6 +1096,8 @@ def __init__(self, device_index: int):
         # mapping from function id to wrapped function
         self.ids_to_funcs: Dict[FunctionID, WrappedFunction] = {}
 
+        self.ids_to_stack_traces: Dict[FunctionID, StackTraces] = {}
+
         self.warmed_up_functions: Set[FunctionID] = set()
 
         with torch.cuda.device(device_index):
@@ -1194,6 +1226,7 @@ def record_function(self, new_inputs, function_id) -> List[Optional[Tensor]]:
             new_inputs,
             self.cuda_graphs_thread_pool,
             self.device_index,
+            self.ids_to_stack_traces[function_id],
         )
         if self.current_node is None:
             self.roots[function_id].append(node)
@@ -1220,6 +1253,7 @@ def run_eager(self, new_inputs, function_id: FunctionID):
             self.cuda_graphs_thread_pool,
             self.graph,
             self.device_index,
+            self.ids_to_stack_traces[function_id],
         )
         self.current_node = node
         self.path_state = ExecutionState.WARMUP
@@ -1240,22 +1274,15 @@ def add_function(
         model,
         inputs,
         static_input_idxs,
-        is_backward,
-        is_inference,
+        stack_traces,
+        mode,
     ) -> Callable:
         id = self.new_func_id()
+        self.ids_to_stack_traces[id] = stack_traces
        self.ids_to_funcs[id] = WrappedFunction(
            model, remove_unaligned_input_idxs(inputs, static_input_idxs), id
        )
-        self.id_to_mode[id] = (
-            CompilationMode.BACKWARD
-            if is_backward
-            else (
-                CompilationMode.INFERENCE if is_inference else CompilationMode.FORWARD
-            )
-        )
-
-        comp_context = torch._functorch.aot_autograd.get_graph_being_compiled()
+        self.id_to_mode[id] = mode
         fn = functools.partial(self.run, function_id=id)
 
         # container needs to set clean up when fn dies
@@ -1345,9 +1372,21 @@ def try_end_curr_warmup(self):
     def dealloc_current_path_weakrefs(self):
         # TODO: we could also allow the these weak refs to continue to be allocated,
         # but that adds some complications.
-        for t in self.current_node.path_live_weakrefs():
+        for t, stack_trace in self.current_node.path_live_weakrefs_and_stacktraces():
+            # TODO: dont need to test t(), but would need to deduplicate storages
             if t():
                 torch._C._free_And_Remove_DeleterFn(t())
+                stack_trace = (
+                    stack_trace.strip()
+                    if stack_trace
+                    else "[Could not find stack trace]"
+                )
+                warnings.warn(
+                    f"CUDAGraphTrees triggered deallocating tensor output from {stack_trace}. "
+                    "Subsequent use of this storage may return garbage result. "
+                    "Outside of torch.compile(), clone the corresponding tensor for safety, or "
+                    "deallocate the corresponding output no longer in use."
+                )
 
     def clear_current_node_outputs_and_set_to_none(self):
         self.current_node.clear_path_outputs()
@@ -1377,7 +1416,7 @@ def apply_checkpoint_execution_state_in_allocator(self):
             torch._C._cuda_cudaCachingAllocator_raw_delete(ptr)
 
         # Now the live blocks should be exactly equal to the live storages in private pool
-        if config.triton.debug_cudagraph_trees:
+        if config.triton.fast_cudagraph_asserts:
             check_memory_pool(self.cuda_graphs_thread_pool, live_storages_wrappers)
 
     def live_cudagraph_pool_storages_in_curr_execution(
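
During debugging, the new warning can be promoted to an error with the standard warnings machinery; a sketch (the `message` filter is a regex matched against the start of the warning text emitted above):

import warnings

# Turn "CUDAGraphTrees triggered deallocating tensor output from ..." into an
# exception so the offending allocation site surfaces in a traceback.
warnings.filterwarnings("error", message="CUDAGraphTrees triggered deallocating")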
