@@ -70,6 +70,34 @@ def get_bdim_choices(num_tensors):
     assert choices[-1] == (None,) * num_tensors
     return tuple(choices[:-1])
 
+# NB: This is O(2 ** num_tensors).
+# num_tensors ranges from 1 to 10, with 2-4 being most common.
+# Try not to make it any more expensive if you're modifying it.
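+# For example, num_tensors=3 with running stats given enumerates
+# ((0, 0, 0), (-1, -1, -1), (None, -1, -1), (None, -1, None), (None, None, -1));
+# choices that batch the input without batching both running stats are skipped.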
+def get_bdim_choices_batch_norm(num_tensors, _, running_mean=None, running_var=None, *args):
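+    # Called as get_bdim_choices_batch_norm(num_tensors, *arg_values): `_` is the
+    # input tensor, and only whether running_mean/running_var were provided matters here.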
+    choices = []
+    options = (-1, None)
+
+    # instance norm turns these into unbatched 0 tensors, so we cannot batch the input if either is not specified
+    if running_mean is None or running_var is None:
+        choices.append((None,) + (0,) * (num_tensors - 1))
+        for choice in itertools.product(options, repeat=num_tensors - 1):
+            choices.append((None,) + choice)
+
+    else:
+        # running_mean and running_var are specified as tensors. Batch norm doesn't work if the input is batched but
+        # running_mean/var are unbatched, so this tests all other cases
+        choices.append((0,) * num_tensors)
+        for choice in itertools.product(options, repeat=num_tensors):
+            input_bdim = choice[0]
+            running_mean_bdim = choice[1]
+            running_var_bdim = choice[2]
+            if input_bdim and (not running_mean_bdim or not running_var_bdim):
+                continue
+            choices.append(choice)
+
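+    # itertools.product yields the all-None (fully unbatched) choice last; assert
+    # that and drop it, mirroring get_bdim_choices above.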
+    assert choices[-1] == (None,) * num_tensors
+    return tuple(choices[:-1])
+
 
 
 def add_batch_dim(arg, bdim, batch_size=3):
     assert bdim == 0 or bdim == -1
@@ -93,12 +121,7 @@ def construct_in_dims(bdim_choice_for_tensors, is_tensors):
         result.append(next(bdim))
     return tuple(result)
 
-
-def get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size=2, *, for_batch_norm=False):
-    if for_batch_norm:
-        # TODO: delete this path
-        return get_exhaustive_batched_inputs_batch_norm(arg_values, kwarg_values, batch_size)
-
+def get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size=2):
     flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
     is_tensors = [isinstance(a, torch.Tensor) for a in flat_args]
     bdim_choices = get_bdim_choices(sum(is_tensors))
@@ -120,87 +143,41 @@ def get_batched_arg(arg, bdim):
         yield batched_args, in_dims, kwarg_values
 
 
-def get_exhaustive_batched_inputs_batch_norm(arg_values, kwarg_values, batch_size=3, bdims=(0, -1)):
-    for_batch_norm = True
-    assert bdims == (0,) or bdims == (0, -1)
-
-    def add_batch_dim(arg, bdim, batch_size=3):
-        assert bdim == 0 or bdim == -1
-        if isinstance(arg, torch.Tensor):
-            if bdim == 0:
-                shape = [1] * len(arg.shape)
-                shape.insert(bdim, batch_size)
-                return (arg.repeat(shape), bdim)
-            if bdim == -1:
-                arg = arg.unsqueeze(-1).expand(*arg.shape, batch_size).contiguous()
-                return (arg, bdim)
-            assert False
-        else:
-            return (arg, None)
-    for bdim in bdims:
-        batch_choices = []
-
-        def add_batch_choices(a):
-            if isinstance(a, torch.Tensor):
-                batched_val = add_batch_dim(a, bdim, batch_size)
-                batch_choices.append((batched_val, (a, None)))
-            else:
-                batch_choices.append(((a, None),))
-
-        flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
-        if for_batch_norm:
-            # Batch norm is unique because the running_mean and running_var are updated in place.
-            # Therefore, they cannot be unbatched if the input is batched. The case where both are
-            # unbatched is added at the end
-            if len(flat_args) >= 3:
-                add_batch_choices(flat_args[0])  # input can be batched or unbatched
-                batch_choices.append((add_batch_dim(flat_args[1], bdim, batch_size),))  # running_mean must be batched
-                batch_choices.append((add_batch_dim(flat_args[2], bdim, batch_size),))  # running_var must be batched
-                orig_flat_args = flat_args
-                flat_args = orig_flat_args[3:]
-            else:
-                # TODO: None defaults in instance norm create empty tensors that are written to and mean that we must
-                # have unbatched inputs. None in the running mean/running var shouldn't make a tensor
-                batch_choices.append(((flat_args[0], None),))  # input must be unbatched
-                if len(flat_args) == 2:
-                    batch_choices.append((add_batch_dim(flat_args[1], bdim, batch_size),))
-                orig_flat_args = flat_args
-                flat_args = []
-
-        for arg in flat_args:
-            add_batch_choices(arg)
-
-        for batched_values in itertools.product(*batch_choices):
-            batched_args, in_dims = zip(*batched_values)
-
-            if all([i is None for i in in_dims]):
-                continue
-
-            yield pytree.tree_unflatten(batched_args, arg_spec), pytree.tree_unflatten(in_dims, arg_spec), kwarg_values
+def get_exhaustive_batched_inputs_batch_norm(arg_values, kwarg_values, batch_size=2):
+    flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
+    is_tensors = [isinstance(a, torch.Tensor) for a in flat_args]
+    num_tensors = sum(is_tensors)
+    if num_tensors == 1:  # if there's only an input, can't batch it since running_mean/var will be seen as unbatched tensors
+        return
+    bdim_choices = get_bdim_choices_batch_norm(num_tensors, *arg_values)
 
-        if for_batch_norm and len(orig_flat_args) >= 2:
-            # Adds the case where input, running_mean, and running_var are all unbatched
-            batch_choices[0] = ((orig_flat_args[0], None),)
-            batch_choices[1] = ((orig_flat_args[1], None),)
-            if len(orig_flat_args) >= 3:
-                batch_choices[2] = ((orig_flat_args[2], None),)
-            for batched_values in itertools.product(*batch_choices):
-                batched_args, in_dims = zip(*batched_values)
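+    # memoized so that each (arg, bdim) pair is materialized only once across
+    # the up-to-O(2 ** num_tensors) bdim choices iterated below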
+    @memoize
+    def get_batched_arg(arg, bdim):
+        assert isinstance(arg, torch.Tensor)
+        assert bdim is not None
+        result, _ = add_batch_dim(arg, bdim, batch_size)
+        return result
 
-                if all([i is None for i in in_dims]):
-                    continue
+    for bdim_choice in bdim_choices:
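+        # construct_in_dims expands the per-tensor bdim choice into one in_dim per
+        # flat arg, leaving None for the non-tensor args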
+        flat_in_dims = construct_in_dims(bdim_choice, is_tensors)
 
-                batched_args_tuple = pytree.tree_unflatten(batched_args, arg_spec)
-                in_dims_tuple = pytree.tree_unflatten(in_dims, arg_spec)
-                yield batched_args_tuple, in_dims_tuple, kwarg_values
+        flat_batched_args = tuple(arg if in_dim is None else get_batched_arg(arg, in_dim)
+                                  for arg, in_dim in zip(flat_args, flat_in_dims))
+        batched_args = pytree.tree_unflatten(flat_batched_args, arg_spec)
+        in_dims = pytree.tree_unflatten(flat_in_dims, arg_spec)
+        yield batched_args, in_dims, kwarg_values
 
 
 def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, opinfo=None, compute_loop_out=True):
     out_dim = 0
     batch_size = 2
     batch_norm_fns = ("nn.functional.batch_norm", "nn.functional.instance_norm")  # instance norm calls batch norm
-    for_batch_norm = opinfo is not None and opinfo.name in batch_norm_fns
-    generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size, for_batch_norm=for_batch_norm)
+
+    if opinfo is not None and opinfo.name in batch_norm_fns:
+        generator = get_exhaustive_batched_inputs_batch_norm(arg_values, kwarg_values, batch_size)
+    else:
+        generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size)
+
     for batched_args, in_dims, kwarg_values in generator:
         if compute_loop_out:
             loop_out = loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values)