Commit ca3ac11

Author: Samantha Andow

Generate n^2 not n^3 inputs for batch and instance norm; small batch norm fix (#951)

* refactor batch norm exhaustive inputs
* fix typo in batch rule
* fix expand issue, add without cudnn xfail
1 parent 4f25800 commit ca3ac11

File tree

3 files changed (+71, -80 lines)


functorch/csrc/BatchRulesNorm.cpp

Lines changed: 7 additions & 1 deletion
@@ -75,7 +75,7 @@ batch_norm_batch_rule(
     mean = std::get<1>(result);
     rstd = std::get<2>(result);
   } else {
-    bdim_size = get_bdim_size3(input, input_bdim, running_mean, running_mean_bdim, running_var, running_mean_bdim);
+    bdim_size = get_bdim_size3(input, input_bdim, running_mean, running_mean_bdim, running_var, running_var_bdim);
     auto input_ = moveBatchDimToFront(input, input_bdim);
     input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size.value());
     input_ = reshape_dim_into(0, /*channels dim*/1, input_);
@@ -86,11 +86,17 @@ batch_norm_batch_rule(
       running_mean_ = moveBatchDimToFront(running_mean, running_mean_bdim);
       running_mean_ = ensure_has_bdim(*running_mean_, running_mean_bdim.has_value(), bdim_size.value());
       running_mean_ = reshape_dim_into(0, 0, *running_mean_);
+      if (training) {
+        running_mean_ = running_mean_->contiguous();
+      }
     }
     if (running_var.defined()) {
       running_var_ = moveBatchDimToFront(running_var, running_var_bdim);
       running_var_ = ensure_has_bdim(*running_var_, running_var_bdim.has_value(), bdim_size.value());
       running_var_ = reshape_dim_into(0, 0, *running_var_);
+      if (training) {
+        running_var_ = running_var_->contiguous();
+      }
     }

     const auto dummy_weight = at::ones(input_.size(1), input_.options()); // cudnn and miopen require a weight
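For context, a minimal sketch (not from this commit, names and shapes assumed) of the training-mode case the batch rule above handles: vmap over batch_norm where the input and the running stats all carry a batch dimension, which is the path that now calls contiguous().

# Sketch only: vmap-ing batch_norm in training mode with batched running stats,
# the path where the rule above now makes the reshaped stats contiguous.
import torch
import torch.nn.functional as F
from functorch import vmap

B, N, C = 2, 5, 3                 # vmap batch size, per-sample batch, channels
x = torch.randn(B, N, C)          # input batched over dim 0
running_mean = torch.zeros(B, C)  # running stats batched over dim 0 as well
running_var = torch.ones(B, C)

def bn(inp, mean, var):
    # training=True updates the (batched) running stats in place,
    # which is why the batch rule reshapes them into the channel dim
    return F.batch_norm(inp, mean, var, training=True)

out = vmap(bn)(x, running_mean, running_var)  # in_dims defaults to 0 for every arg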

test/common_utils.py

Lines changed: 55 additions & 78 deletions
@@ -70,6 +70,34 @@ def get_bdim_choices(num_tensors):
     assert choices[-1] == (None,) * num_tensors
     return tuple(choices[:-1])

+# NB: This is O(2 ** num_tensors).
+# num_tensors ranges from 1 to 10, with 2-4 being most common.
+# Try not to extravagate it if you're modifying it.
+def get_bdim_choices_batch_norm(num_tensors, _, running_mean=None, running_var=None, *args):
+    choices = []
+    options = (-1, None)
+
+    # instance norm turns these into unbatched 0 tensors, so we cannot batch the input if either is not specified
+    if running_mean == None or running_var == None:
+        choices.append((None,) + (0,) * (num_tensors - 1))
+        for choice in itertools.product(options, repeat=num_tensors - 1):
+            choices.append((None,) + choice)
+
+    else:
+        # running_mean and running_var are specified as tensors. Batch norm doesn't work if the input is batched but
+        # running_mean/var are unbatched, so this tests all other cases
+        choices.append((0,) * num_tensors)
+        for choice in itertools.product(options, repeat=num_tensors):
+            input_bdim = choice[0]
+            running_mean_bdim = choice[1]
+            running_var_bdim = choice[2]
+            if input_bdim and (not running_mean_bdim or not running_var_bdim):
+                continue
+            choices.append(choice)
+
+    assert choices[-1] == (None,) * num_tensors
+    return tuple(choices[:-1])
+

 def add_batch_dim(arg, bdim, batch_size=3):
     assert bdim == 0 or bdim == -1
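For illustration (not part of the commit): with input, running_mean, and running_var all supplied, the helper added above enumerates per-tensor batch-dim choices instead of a full cross product, roughly as in the sketch below (import path assumed, matching how test/test_ops.py imports from common_utils).

# Hypothetical call, mirroring how the batch-norm OpInfo samples reach this helper.
import torch
from common_utils import get_bdim_choices_batch_norm  # assumed: test/common_utils.py on the path

inp = torch.randn(2, 3)
running_mean = torch.zeros(3)
running_var = torch.ones(3)

choices = get_bdim_choices_batch_norm(3, inp, running_mean, running_var)
# Expected to contain tuples such as
#   (0, 0, 0)       - everything batched at dim 0
#   (-1, -1, -1)    - everything batched at dim -1
#   (None, -1, -1)  - unbatched input, batched running stats
# but never a batched input paired with unbatched running stats, and the
# fully unbatched (None, None, None) case is dropped.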
@@ -93,12 +121,7 @@ def construct_in_dims(bdim_choice_for_tensors, is_tensors):
         result.append(next(bdim))
     return tuple(result)

-
-def get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size=2, *, for_batch_norm=False):
-    if for_batch_norm:
-        # TODO: delete this path
-        return get_exhaustive_batched_inputs_batch_norm(arg_values, kwarg_values, batch_size)
-
+def get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size=2):
     flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
     is_tensors = [isinstance(a, torch.Tensor) for a in flat_args]
     bdim_choices = get_bdim_choices(sum(is_tensors))
@@ -120,87 +143,41 @@ def get_batched_arg(arg, bdim):
         yield batched_args, in_dims, kwarg_values


-def get_exhaustive_batched_inputs_batch_norm(arg_values, kwarg_values, batch_size=3, bdims=(0, -1)):
-    for_batch_norm = True
-    assert bdims == (0,) or bdims == (0, -1)
-
-    def add_batch_dim(arg, bdim, batch_size=3):
-        assert bdim == 0 or bdim == -1
-        if isinstance(arg, torch.Tensor):
-            if bdim == 0:
-                shape = [1] * len(arg.shape)
-                shape.insert(bdim, batch_size)
-                return (arg.repeat(shape), bdim)
-            if bdim == -1:
-                arg = arg.unsqueeze(-1).expand(*arg.shape, batch_size).contiguous()
-                return (arg, bdim)
-            assert False
-        else:
-            return (arg, None)
-    for bdim in bdims:
-        batch_choices = []
-
-        def add_batch_choices(a):
-            if isinstance(a, torch.Tensor):
-                batched_val = add_batch_dim(a, bdim, batch_size)
-                batch_choices.append((batched_val, (a, None)))
-            else:
-                batch_choices.append(((a, None),))
-
-        flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
-        if for_batch_norm:
-            # Batch norm is unique because the running_mean and running_var are updated in place.
-            # Therefore, they cannot be unbatched if the input is batched. The case where both are
-            # unbatched is added at the end
-            if len(flat_args) >= 3:
-                add_batch_choices(flat_args[0]) # input can be batched or unbatched
-                batch_choices.append((add_batch_dim(flat_args[1], bdim, batch_size),)) # running_mean must be batched
-                batch_choices.append((add_batch_dim(flat_args[2], bdim, batch_size),)) # running_var must be batched
-                orig_flat_args = flat_args
-                flat_args = orig_flat_args[3:]
-            else:
-                # TODO: None defaults in instance norm create empty tensors that are written to and mean that we must
-                # have unbatched inputs. None in the running mean/running var shouldn't make a tensor
-                batch_choices.append(((flat_args[0], None),)) # input must be unbatched
-                if len(flat_args) == 2:
-                    batch_choices.append((add_batch_dim(flat_args[1], bdim, batch_size),))
-                orig_flat_args = flat_args
-                flat_args = []
-
-        for arg in flat_args:
-            add_batch_choices(arg)
-
-        for batched_values in itertools.product(*batch_choices):
-            batched_args, in_dims = zip(*batched_values)
-
-            if all([i is None for i in in_dims]):
-                continue
-
-            yield pytree.tree_unflatten(batched_args, arg_spec), pytree.tree_unflatten(in_dims, arg_spec), kwarg_values
+def get_exhaustive_batched_inputs_batch_norm(arg_values, kwarg_values, batch_size=2):
+    flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
+    is_tensors = [isinstance(a, torch.Tensor) for a in flat_args]
+    num_tensors = sum(is_tensors)
+    if num_tensors == 1: # if there's only an input, can't batch it since running_mean/var will be seen as unbatched tensors
+        return
+    bdim_choices = get_bdim_choices_batch_norm(num_tensors, *arg_values)

-    if for_batch_norm and len(orig_flat_args) >= 2:
-        # Adds the case where input, running_mean, and running_var are all unbatched
-        batch_choices[0] = ((orig_flat_args[0], None),)
-        batch_choices[1] = ((orig_flat_args[1], None),)
-        if len(orig_flat_args) >= 3:
-            batch_choices[2] = ((orig_flat_args[2], None),)
-        for batched_values in itertools.product(*batch_choices):
-            batched_args, in_dims = zip(*batched_values)
+    @memoize
+    def get_batched_arg(arg, bdim):
+        assert isinstance(arg, torch.Tensor)
+        assert bdim is not None
+        result, _ = add_batch_dim(arg, bdim, batch_size)
+        return result

-            if all([i is None for i in in_dims]):
-                continue
+    for bdim_choice in bdim_choices:
+        flat_in_dims = construct_in_dims(bdim_choice, is_tensors)

-            batched_args_tuple = pytree.tree_unflatten(batched_args, arg_spec)
-            in_dims_tuple = pytree.tree_unflatten(in_dims, arg_spec)
-            yield batched_args_tuple, in_dims_tuple, kwarg_values
+        flat_batched_args = tuple(arg if in_dim is None else get_batched_arg(arg, in_dim)
+                                  for arg, in_dim in zip(flat_args, flat_in_dims))
+        batched_args = pytree.tree_unflatten(flat_batched_args, arg_spec)
+        in_dims = pytree.tree_unflatten(flat_in_dims, arg_spec)
+        yield batched_args, in_dims, kwarg_values


 def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, opinfo=None, compute_loop_out=True):
     out_dim = 0
     batch_size = 2
     batch_norm_fns = ("nn.functional.batch_norm", "nn.functional.instance_norm") # instance norm calls batch norm
-    for_batch_norm = opinfo is not None and opinfo.name in batch_norm_fns
-    generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size, for_batch_norm=for_batch_norm)
+
+    if opinfo is not None and opinfo.name in batch_norm_fns:
+        generator = get_exhaustive_batched_inputs_batch_norm(arg_values, kwarg_values, batch_size)
+    else:
+        generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size)
+
     for batched_args, in_dims, kwarg_values in generator:
         if compute_loop_out:
             loop_out = loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values)
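For illustration (not part of the commit), the new generator is consumed the same way as get_exhaustive_batched_inputs, roughly as in the sketch below (eval-mode batch_norm; names, shapes, and import path assumed).

# Sketch only: driving the batch-norm generator the way the test helpers do.
import torch
import torch.nn.functional as F
from functorch import vmap
from common_utils import get_exhaustive_batched_inputs_batch_norm  # assumed import path

args = (torch.randn(5, 3), torch.zeros(3), torch.ones(3))  # input, running_mean, running_var
for batched_args, in_dims, kwargs in get_exhaustive_batched_inputs_batch_norm(args, {}, batch_size=2):
    # in_dims marks which arguments carry the extra vmap dimension (None = unbatched)
    out = vmap(F.batch_norm, in_dims)(*batched_args, **kwargs)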

test/test_ops.py

Lines changed: 9 additions & 1 deletion
@@ -19,6 +19,7 @@
 from common_utils import (
     get_fallback_and_vmap_exhaustive,
     get_exhaustive_batched_inputs,
+    get_exhaustive_batched_inputs_batch_norm,
     xfail,
     skip,
     skipOps,
@@ -663,6 +664,10 @@ def test_vmapvjp(self, device, dtype, op):

     xfail('put'), # calls put_ during vmap with only vmaps over other, not self
     xfail('nn.functional.prelu'), # Call Tensor.as_strided
+
+    # erroring because running_mean and running_var aren't differentiable
+    xfail('nn.functional.batch_norm'),
+    xfail('nn.functional.batch_norm', 'without_cudnn'),
 }

 @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
@@ -964,7 +969,10 @@ def test_vjpvmap(self, device, dtype, op):
         for sample in samples:
             args = [sample.input] + list(sample.args)
             kwargs = sample.kwargs
-            generator = get_exhaustive_batched_inputs(args, kwargs, for_batch_norm=is_batch_norm)
+            if is_batch_norm:
+                generator = get_exhaustive_batched_inputs_batch_norm(args, kwargs)
+            else:
+                generator = get_exhaustive_batched_inputs(args, kwargs)

             for batched_args, in_dims, kwargs in generator:
                 vmapped_op = vmap(op, in_dims)
