@@ -258,14 +258,22 @@ def cudagraphify_impl(
     device_index: int,
     is_backward: bool,
     is_inference: bool,
+    stack_traces: Optional[StackTraces] = None,
 ):
     manager = get_container(device_index).get_tree_manager()
+    assert not (is_backward and is_inference)
+    mode = (
+        CompilationMode.BACKWARD
+        if is_backward
+        else (CompilationMode.INFERENCE if is_inference else CompilationMode.FORWARD)
+    )
+
     return manager.add_function(
         model,
         inputs,
         static_input_idxs,
-        is_backward,
-        is_inference,
+        stack_traces,
+        mode,
     )
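As a side note, here is a minimal, self-contained sketch of the mode selection this hunk performs; the CompilationMode enum below is a simplified stand-in for illustration, not inductor's actual class. Backward and inference are mutually exclusive, and forward is the fallback.

from enum import Enum, auto

class CompilationMode(Enum):  # simplified stand-in, for illustration only
    FORWARD = auto()
    BACKWARD = auto()
    INFERENCE = auto()

def derive_mode(is_backward: bool, is_inference: bool) -> CompilationMode:
    # mirrors the assert and conditional added in cudagraphify_impl above
    assert not (is_backward and is_inference)
    if is_backward:
        return CompilationMode.BACKWARD
    return CompilationMode.INFERENCE if is_inference else CompilationMode.FORWARD

assert derive_mode(False, True) is CompilationMode.INFERENCE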
@@ -351,6 +359,8 @@ def map_to_ref(t: Optional[Tensor]) -> Optional[StorageWeakRefWrapper]:
 # For each node in the path, for each output, is the output alive
 PathLiveness = List[List[bool]]

+StackTraces = List[Optional[str]]
+

 class CUDAWarmupNode:
     """
@@ -378,6 +388,7 @@ def __init__(
         cuda_graphs_pool: Tuple[int, int],
         existing_cuda_graph: torch.cuda.Graph,
         device_index: int,
+        stack_traces: Optional[StackTraces],
     ):
         self.wrapped_function = wrapped_function
         self.parent = parent
@@ -386,6 +397,7 @@ def __init__(
         self.existing_cuda_graph = existing_cuda_graph
         self.has_run = False
         self.device_index = device_index
+        self.stack_traces = stack_traces

     def run(self, new_inputs):
         assert not self.has_run, "Wrapped function should never be run twice"
@@ -403,7 +415,7 @@ def run(self, new_inputs):
             ):
                 non_cudagraph_inps.add(new_inputs[i].untyped_storage().data_ptr())

-        if config.triton.debug_cudagraph_trees:
+        if config.triton.fast_cudagraph_asserts:
             refs = list(self.path_live_weakrefs())
             check_memory_pool(self.cuda_graphs_pool, refs)
@@ -425,7 +437,7 @@ def run(self, new_inputs):
             ]
         )

-        if config.triton.debug_cudagraph_trees:
+        if config.triton.fast_cudagraph_asserts:
             out_refs = self.path_live_weakrefs()
             new_storages = [
                 t for t in out_refs if t.data_ptr() not in non_cudagraph_inps
@@ -436,16 +448,22 @@ def run(self, new_inputs):

     def path_live_weakrefs(self) -> Generator[StorageWeakRefWrapper]:
         "Returns all live storages weakrefs that created by nodes in this path"
+        for stor_ref, _ in self.path_live_weakrefs_and_stacktraces():
+            yield stor_ref
+
+    def path_live_weakrefs_and_stacktraces(
+        self,
+    ) -> Generator[Tuple[StorageWeakRefWrapper, Optional[str]]]:
         nodes = []
         node = self
         while node:
             nodes.append(node)
             node = node.parent

         for node in reversed(nodes):
-            for output in node.outputs_weakrefs:
+            for i, output in enumerate(node.outputs_weakrefs):
                 if is_live(output):
-                    yield output
+                    yield output, (node.stack_traces[i] if node.stack_traces else None)

     def all_outputs_are_dead(self):
         return not list(self.path_live_weakrefs())
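To illustrate the pairing contract of path_live_weakrefs_and_stacktraces, here is a hedged, self-contained toy; the name pair_with_traces and the string outputs are made up for the example. Each live output is matched positionally with its recorded trace, and None is used whenever a node carries no stack traces.

from typing import Iterator, List, Optional, Tuple

def pair_with_traces(
    outputs: List[str], stack_traces: Optional[List[Optional[str]]]
) -> Iterator[Tuple[str, Optional[str]]]:
    # same index-based pairing as the generator above, minus the liveness check
    for i, out in enumerate(outputs):
        yield out, (stack_traces[i] if stack_traces else None)

assert list(pair_with_traces(["buf0", "buf1"], None)) == [("buf0", None), ("buf1", None)]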
@@ -486,12 +504,14 @@ def __init__(
         inputs: List[Tensor],
         cuda_graphs_pool: Tuple[int, int],
         device_index: int,
+        stack_traces: Optional[StackTraces],
     ):
         assert isinstance(inputs, (list, tuple))

         self.wrapped_function = wrapped_function
         self.id = id
         self.device = device_index
+        self.stack_traces = stack_traces

         # if this is a root parent will be None. use weakref to prevent reference cycle
         self._parent = weakref.ref(parent) if parent is not None else None
@@ -510,6 +530,9 @@ def __init__(
         self.path_weakrefs: LevelList[OutputList[Optional[StorageWeakRefWrapper]]] = [
             node.outputs_weakrefs for node in self._path_from_root
         ]
+        self.path_stacktraces: LevelList[StackTraces] = [
+            node.stack_traces for node in self._path_from_root
+        ]

         # tensors which are outputs of previous graphs in the tree
         self.cudagraph_managed_idxs: List[int] = [
@@ -616,7 +639,7 @@ def __init__(
         self.checkpointed_caching_state: Optional[AllocatorState] = None

     def run(self, new_inputs):
-        if config.triton.debug_cudagraph_trees:
+        if config.triton.slow_cudagraph_asserts:
             self.debug_check_invariants_before_invocation()

         assert len(self.static_input_data_ptrs) == len(new_inputs)
@@ -677,7 +700,7 @@ def all_outputs_are_dead(self):
     def _record(self, model, stream, inputs):
         "Record the model"

-        if config.triton.debug_cudagraph_trees:
+        if config.triton.fast_cudagraph_asserts:
             # need to use parent live weakrefs because live_indices isnt set yet
             memory = (
                 [] if self.parent is None else list(self.parent.path_live_weakrefs())
@@ -720,6 +743,13 @@ def _add_first_outputs(self, outputs):
                 and o.untyped_storage().data_ptr() in self.static_input_storage_ptrs
             )

+        if self.stack_traces is None:
+            self.stack_traces = [None for _ in range(len(outputs))]
+        else:
+            assert len(self.stack_traces) == len(
+                outputs
+            ), "Wrong number of stack traces passed in"
+
         self._add_replayed_outputs(outputs)
         self.recorded_liveness_after_graph = self._get_liveness(self.path_weakrefs)
@@ -734,7 +764,7 @@ def _add_first_outputs(self, outputs):
                 self.live_indices_after_graph.append((depth, output_index))

         self.debug_check_invariants_after_invocation()
-        if config.triton.debug_cudagraph_trees:
+        if config.triton.fast_cudagraph_asserts:
             check_memory_pool(self.cuda_graphs_pool, list(self.path_live_weakrefs()))

     def _add_replayed_outputs(self, outputs):
@@ -816,7 +846,7 @@ def _get_liveness(
     def debug_assert_invariants(
         self, expected_liveness: List[List[bool]], newly_dead: List[PathOutputIndex]
     ):
-        if not config.triton.debug_cudagraph_trees:
+        if not config.triton.slow_cudagraph_asserts:
             return

         for i, node in enumerate(self._path_from_root):
@@ -1066,6 +1096,8 @@ def __init__(self, device_index: int):
         # mapping from function id to wrapped function
         self.ids_to_funcs: Dict[FunctionID, WrappedFunction] = {}

+        self.ids_to_stack_traces: Dict[FunctionID, StackTraces] = {}
+
         self.warmed_up_functions: Set[FunctionID] = set()

         with torch.cuda.device(device_index):
@@ -1194,6 +1226,7 @@ def record_function(self, new_inputs, function_id) -> List[Optional[Tensor]]:
             new_inputs,
             self.cuda_graphs_thread_pool,
             self.device_index,
+            self.ids_to_stack_traces[function_id],
         )
         if self.current_node is None:
             self.roots[function_id].append(node)
@@ -1220,6 +1253,7 @@ def run_eager(self, new_inputs, function_id: FunctionID):
             self.cuda_graphs_thread_pool,
             self.graph,
             self.device_index,
+            self.ids_to_stack_traces[function_id],
         )
         self.current_node = node
         self.path_state = ExecutionState.WARMUP
@@ -1240,22 +1274,15 @@ def add_function(
         model,
         inputs,
         static_input_idxs,
-        is_backward,
-        is_inference,
+        stack_traces,
+        mode,
     ) -> Callable:
         id = self.new_func_id()
+        self.ids_to_stack_traces[id] = stack_traces
         self.ids_to_funcs[id] = WrappedFunction(
             model, remove_unaligned_input_idxs(inputs, static_input_idxs), id
         )
-        self.id_to_mode[id] = (
-            CompilationMode.BACKWARD
-            if is_backward
-            else (
-                CompilationMode.INFERENCE if is_inference else CompilationMode.FORWARD
-            )
-        )
-
-        comp_context = torch._functorch.aot_autograd.get_graph_being_compiled()
+        self.id_to_mode[id] = mode
         fn = functools.partial(self.run, function_id=id)

         # container needs to set clean up when fn dies
@@ -1345,9 +1372,21 @@ def try_end_curr_warmup(self):
     def dealloc_current_path_weakrefs(self):
         # TODO: we could also allow the these weak refs to continue to be allocated,
         # but that adds some complications.
-        for t in self.current_node.path_live_weakrefs():
+        for t, stack_trace in self.current_node.path_live_weakrefs_and_stacktraces():
+            # TODO: dont need to test t(), but would need to deduplicate storages
             if t():
                 torch._C._free_And_Remove_DeleterFn(t())
+                stack_trace = (
+                    stack_trace.strip()
+                    if stack_trace
+                    else "[Could not find stack trace]"
+                )
+                warnings.warn(
+                    f"CUDAGraphTrees triggered deallocating tensor output from {stack_trace}. "
+                    "Subsequent use of this storage may return garbage result. "
+                    "Outside of torch.compile(), clone the corresponding tensor for safety, or "
+                    "deallocate the corresponding output no longer in use."
+                )

     def clear_current_node_outputs_and_set_to_none(self):
         self.current_node.clear_path_outputs()
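A hedged sketch of the remediation the new warning recommends; run_and_keep and compiled_fn are hypothetical names, not part of this PR. Cloning an output outside of torch.compile() copies it out of the CUDA-graph-owned storage before the tree reclaims that memory.

import torch

def run_and_keep(compiled_fn, *args):
    out = compiled_fn(*args)
    # clone so later reads don't observe garbage once the private pool is reused
    return out.clone() if isinstance(out, torch.Tensor) else out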
@@ -1377,7 +1416,7 @@ def apply_checkpoint_execution_state_in_allocator(self):
             torch._C._cuda_cudaCachingAllocator_raw_delete(ptr)

         # Now the live blocks should be exactly equal to the live storages in private pool
-        if config.triton.debug_cudagraph_trees:
+        if config.triton.fast_cudagraph_asserts:
             check_memory_pool(self.cuda_graphs_thread_pool, live_storages_wrappers)

     def live_cudagraph_pool_storages_in_curr_execution(