diff --git a/torchdynamo/convert_frame.py b/torchdynamo/convert_frame.py index ec5a685cb3..1ce0425ae3 100644 --- a/torchdynamo/convert_frame.py +++ b/torchdynamo/convert_frame.py @@ -149,6 +149,7 @@ def transform(instructions, code_options): ) tracer.run() output = tracer.output + output.cleanup() assert output.output_instructions instructions[:] = output.output_instructions code_options.update(output.code_options) diff --git a/torchdynamo/output_graph.py b/torchdynamo/output_graph.py index e86ea0c510..b53120189d 100644 --- a/torchdynamo/output_graph.py +++ b/torchdynamo/output_graph.py @@ -365,3 +365,12 @@ def add_output_instructions(self, prefix: List[Instruction]): def install_global(self, name, value): self.cleanups.append(CleanupHook.create(self.root_globals, name, value)) + + def cleanup(self): + # There is a reference cycle between tracer and OutputGraph, causing + # some of the tensor objects to be held alive for longer than necessary. + self.root_tx = None + + # Cleanup graphargs + for graph_arg in self.graphargs: + graph_arg.erase() diff --git a/torchdynamo/variables/builder.py b/torchdynamo/variables/builder.py index 7a19eccfc9..469792a226 100644 --- a/torchdynamo/variables/builder.py +++ b/torchdynamo/variables/builder.py @@ -70,6 +70,9 @@ def get_examples(self): def __len__(self): return 1 + def erase(self): + self.example = None + class VariableBuilder: """Wrap a python value in a VariableTracker() instance"""