From ea5c525249e59090e97a154f29641a9754306001 Mon Sep 17 00:00:00 2001
From: Horace He
Date: Thu, 5 May 2022 02:00:10 +0000
Subject: [PATCH 1/2] added list clearing for backwards

---
 functorch/_src/aot_autograd.py | 30 ++++++++++++++++++++++++++++--
 functorch/_src/compilers.py    |  1 +
 2 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py
index 36306dbc8..c2be561b9 100644
--- a/functorch/_src/aot_autograd.py
+++ b/functorch/_src/aot_autograd.py
@@ -116,6 +116,26 @@ def new_zeros(inp, size, dtype=None, layout=None, device=None, pin_memory=None):
 def new_full(inp, size, value, dtype=None, layout=None, device=None, pin_memory=None):
     return torch.full(size, value, dtype=inp.dtype, device=inp.device)
 
+import torch.fx as fx
+import typing
+class ListCodeGen(fx.CodeGen):
+    def gen_fn_def(self, free_vars, maybe_return_annotation):
+        lst_unpack = f"""
+def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
+    {', '.join(free_vars)} = args_list
+    args_list.clear()
+    """
+        return lst_unpack
+
+    def additional_globals(self):
+        return [('List', typing.List)]
+
+    def process_inputs(self, *inputs):
+        assert(len(inputs) == 1)
+        return inputs[0]
+
+def get_memory():
+    print(torch.cuda.max_memory_allocated()/1e9)
 
 def create_aot_autograd_function(
     flat_fn, fw_compiler, bw_compiler, partition_fn, decompositions, grad_state
@@ -168,18 +188,24 @@ def forward(ctx, *flat_tensor_args):
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
 
                 bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
+                bw_module.graph.set_codegen(ListCodeGen())
+                bw_module.recompile()
                 compiled_bw = bw_compiler(bw_module, bw_args)
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
+            # No way of clearing ctx.saved_tensors right now afaik
             ctx.save_for_backward(*fw_outs[num_outs:])
+            # ctx.saved_values = fw_outs[num_outs:]
             return tuple(fw_outs[0:num_outs])
 
         @staticmethod
         @disable_torchdynamo
         def backward(ctx, *flat_args):
             contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
-            out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
+            # import pdb; pdb.set_trace()
+            flat_args = list(ctx.saved_values) + list(contiguous_args)
+            ctx.saved_values = None
+            out = normalize_as_list(compiled_bw(flat_args))
             return tuple(out)
 
     return CompiledFunction
diff --git a/functorch/_src/compilers.py b/functorch/_src/compilers.py
index 4ce5fd562..5cb2ef68f 100644
--- a/functorch/_src/compilers.py
+++ b/functorch/_src/compilers.py
@@ -33,6 +33,7 @@ def ts_compile(fx_g: fx.GraphModule, _) -> Callable:
     Returns:
         Torch scripted model.
     """
+    return fx_g
     for node in fx_g.graph.nodes:
         if node.target in (torch.ops.aten.new_zeros, torch.ops.aten.new_empty):
             if node.args[1] == []:

From d2bc05564c129f2a05e0066b7a61c09b3ee323d4 Mon Sep 17 00:00:00 2001
From: Horace He
Date: Thu, 5 May 2022 06:27:42 +0000
Subject: [PATCH 2/2] fix some stuff

---
 functorch/_src/aot_autograd.py | 4 +---
 functorch/_src/compilers.py    | 1 -
 2 files changed, 1 insertion(+), 4 deletions(-)

diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py
index c2be561b9..66f1fc794 100644
--- a/functorch/_src/aot_autograd.py
+++ b/functorch/_src/aot_autograd.py
@@ -194,15 +194,13 @@ def forward(ctx, *flat_tensor_args):
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
             # No way of clearing ctx.saved_tensors right now afaik
-            ctx.save_for_backward(*fw_outs[num_outs:])
-            # ctx.saved_values = fw_outs[num_outs:]
+            ctx.saved_values = fw_outs[num_outs:]
             return tuple(fw_outs[0:num_outs])
 
         @staticmethod
         @disable_torchdynamo
         def backward(ctx, *flat_args):
             contiguous_args = [t.contiguous() for t in flat_args]
-            # import pdb; pdb.set_trace()
             flat_args = list(ctx.saved_values) + list(contiguous_args)
             ctx.saved_values = None
             out = normalize_as_list(compiled_bw(flat_args))
diff --git a/functorch/_src/compilers.py b/functorch/_src/compilers.py
index 5cb2ef68f..4ce5fd562 100644
--- a/functorch/_src/compilers.py
+++ b/functorch/_src/compilers.py
@@ -33,7 +33,6 @@ def ts_compile(fx_g: fx.GraphModule, _) -> Callable:
     Returns:
         Torch scripted model.
     """
-    return fx_g
     for node in fx_g.graph.nodes:
         if node.target in (torch.ops.aten.new_zeros, torch.ops.aten.new_empty):
             if node.args[1] == []: