
Commit aed7c7e

Update on "[WIP] Compute forward grads for saved_mean and saved_var when input requires grad"
We want to avoid having to recompute saved_mean and saved_invstd in batch_norm_backward's decomposition in functorch (see pytorch/functorch#877), but also avoid unnecessarily computing forward grads for saved_mean and saved_invstd when they are not needed.

Tested locally with: `python test/test_ops.py -k test_jvpvjp_nn_functional_batch_norm`

Issues:
- not sure if gradgrad in core is missing something, but it is able to pass while the fwgrad_bwgrad comparison fails in functorch

[ghstack-poisoned]
2 parents 8af0a6a + 6ffc0a9 commit aed7c7e
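
For background, the "forward grads" the commit message refers to are the tangents propagated by forward-mode AD. A minimal sketch of pushing a tangent through `torch.nn.functional.batch_norm` (the shapes and variable names here are illustrative only, not taken from the commit):

```python
import torch
import torch.nn.functional as F
import torch.autograd.forward_ad as fwAD

x = torch.randn(8, 3)        # (batch, channels) input
tangent = torch.randn(8, 3)  # direction for the JVP
running_mean = torch.zeros(3)
running_var = torch.ones(3)

with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, tangent)
    # training=True makes batch_norm compute batch statistics
    # (the saved_mean / saved_invstd this commit is concerned with).
    out = F.batch_norm(dual_x, running_mean, running_var, training=True)
    primal, jvp = fwAD.unpack_dual(out)
```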

File tree

1 file changed: +8 −7 lines


torch/autograd/function.py

Lines changed: 8 additions & 7 deletions
@@ -22,15 +22,16 @@ def save_for_backward(self, *tensors: torch.Tensor):
         incorrect gradients and memory leaks, and enable the application of saved
         tensor hooks. See :class:`torch.autograd.graph.saved_tensors_hooks`.
 
-        Note that if intermediary tensors (i.e., tensors that are neither input
-        nor output) are saved for backward, your custom Function may not support
-        `double backward <https://pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html>`_.
+        Note that if intermediary tensors (tensors that are neither input
+        nor output of :func:`forward`) are saved for backward, your custom Function
+        may not support double backward.
         Custom Functions that do not support double backward should decorate their
-        :func:`backward` method with `@once_differentiable` so that performing
-        double backward raises an error. If you'd like to support double backawrd
+        :func:`backward` method with ``@once_differentiable`` so that performing
+        double backward raises an error. If you'd like to support double backward
         you can either recompute intermediaries based on the inputs during backward
-        or return the intermediaries as the outputs of the custom Function. See
-        the tutorial linked above for more details.
+        or return the intermediaries as the outputs of the custom Function. See the
+        `double backward tutorial <https://pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html>`_.
+        for more details.
 
         In :func:`backward`, saved tensors can be accessed through the :attr:`saved_tensors`
         attribute. Before returning them to the user, a check is made to ensure
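
The docstring change above describes the pattern for custom Functions that save intermediary tensors. A minimal sketch of that pattern, using a hypothetical `MySoftplus` Function (not part of this commit):

```python
import torch
from torch.autograd.function import once_differentiable


class MySoftplus(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        sig = torch.sigmoid(x)      # intermediary: neither input nor output
        ctx.save_for_backward(sig)
        return torch.nn.functional.softplus(x)

    @staticmethod
    @once_differentiable            # double backward will raise an error
    def backward(ctx, grad_output):
        (sig,) = ctx.saved_tensors
        return grad_output * sig    # d/dx softplus(x) = sigmoid(x)


x = torch.randn(5, dtype=torch.double, requires_grad=True)
MySoftplus.apply(x).sum().backward()  # first-order backward works
# Differentiating through this backward (double backward) would raise,
# because the intermediary `sig` was saved instead of being recomputed
# or returned as an output.
```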
