You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Update base for Update on "[WIP] Compute forward grads for saved_mean and saved_var when input requires grad"
We want to avoid having to recompute saved_mean and saved_invstd in batch_norm_backward's decomposition in functorch (see pytorch/functorch#877), but also avoid unnecessarily computing forward grads for saved_mean and saved_invstd when they are not needed.
Tested locally with: `python test/test_ops.py -k test_jvpvjp_nn_functional_batch_norm`
Issues:
- not sure if gradgrad in core is missing something, but it is able to pass while the fwgrad_bwgrad comparison fails in functorch
[ghstack-poisoned]
0 commit comments