From b016e1874923722df1ba51ac0a53c967119940d4 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 10 May 2022 00:09:41 +0000 Subject: [PATCH] Remove reference cycle --- torchdynamo/convert_frame.py | 1 + torchdynamo/output_graph.py | 9 +++++++++ torchdynamo/variables/builder.py | 3 +++ 3 files changed, 13 insertions(+) diff --git a/torchdynamo/convert_frame.py b/torchdynamo/convert_frame.py index ec5a685cb3..1ce0425ae3 100644 --- a/torchdynamo/convert_frame.py +++ b/torchdynamo/convert_frame.py @@ -149,6 +149,7 @@ def transform(instructions, code_options): ) tracer.run() output = tracer.output + output.cleanup() assert output.output_instructions instructions[:] = output.output_instructions code_options.update(output.code_options) diff --git a/torchdynamo/output_graph.py b/torchdynamo/output_graph.py index e86ea0c510..b53120189d 100644 --- a/torchdynamo/output_graph.py +++ b/torchdynamo/output_graph.py @@ -365,3 +365,12 @@ def add_output_instructions(self, prefix: List[Instruction]): def install_global(self, name, value): self.cleanups.append(CleanupHook.create(self.root_globals, name, value)) + + def cleanup(self): + # There is a reference cycle between tracer and OutputGraph, causing + # some of the tensor objects to be held alive for longer than necessary. + self.root_tx = None + + # Cleanup graphargs + for graph_arg in self.graphargs: + graph_arg.erase() diff --git a/torchdynamo/variables/builder.py b/torchdynamo/variables/builder.py index 7a19eccfc9..469792a226 100644 --- a/torchdynamo/variables/builder.py +++ b/torchdynamo/variables/builder.py @@ -70,6 +70,9 @@ def get_examples(self): def __len__(self): return 1 + def erase(self): + self.example = None + class VariableBuilder: """Wrap a python value in a VariableTracker() instance"""