Skip to content

Commit e139706

Browse files
committed
Debugged grouping.
1 parent 40f8e83 commit e139706

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

model_average.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,12 @@ def train(gpu, args):
6565
actual_lr = args.lr * sample_ray_num / 512 # bigger batch -> higher lr (linearity)
6666
ma_epoch = args.ma_epoch
6767
ma_method = args.ma_method
68-
group = None if not args.group else args.group
6968

7069
train_cnt, ep_start = None, None
7170

7271
rank = args.nr * args.gpus + gpu
73-
dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank, group_name = group)
72+
dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank)
73+
process_group = dist.new_group(backend = 'nccl')
7474
torch.cuda.set_device(gpu)
7575

7676
for folder in ("./output/", "./check_points/", "./model/"):
@@ -229,31 +229,31 @@ def run():
229229
train_sampler.set_epoch(train_cnt)
230230
if ep % ma_epoch == 0:
231231
# double barrier to ensure synchronized sending / receiving
232-
dist.barrier()
232+
dist.barrier(group = process_group)
233233
comm_timer.tic()
234234
print(f"Using model average, method: {args.ma_method}... ", end = '')
235235
if ma_method == 'p2p':
236236
# This is a serialized reduce-broadcast (a central node exists)
237237
if rank == 0:
238-
param_recv_avg(mip_net, container, model_weights, [1, 2, 3], group = group)
238+
param_recv_avg(mip_net, container, model_weights, [1, 2, 3], group = process_group)
239239
# Receive from multiple nodes
240-
param_send(mip_net, dist_ranks = [1, 2, 3], group = group)
240+
param_send(mip_net, dist_ranks = [1, 2, 3], group = process_group)
241241
else:
242-
param_send(mip_net, dist_ranks = [0], group = group)
242+
param_send(mip_net, dist_ranks = [0], group = process_group)
243243
# Receive from only one node
244-
param_recv(mip_net, source_rank = 0, group = group)
244+
param_recv(mip_net, source_rank = 0, group = process_group)
245245
elif ma_method == 'broadcast': # reduce-broadcast (one of the node is the bottleneck)
246-
param_reduce(mip_net, model_weights, rank, 0, group = group)
247-
param_broadcast(mip_net, 0, group = group)
246+
param_reduce(mip_net, model_weights, rank, 0, group = process_group)
247+
param_broadcast(mip_net, 0, group = process_group)
248248
elif ma_method == 'all_reduce': # all-reduce (one-step reduce-broadcast)
249249
for param in mip_net.parameters():
250250
param.data *= model_weights[rank]
251-
param_all_reduce(mip_net, group = group)
251+
param_all_reduce(mip_net, group = process_group)
252252
else:
253253
# TODO: a more delicate communication strategy should be implemented
254254
# This is basically the case with correlated camera poses
255255
pass
256-
dist.barrier()
256+
dist.barrier(group = process_group)
257257
comm_timer.toc()
258258
mean_comm_time = comm_timer.get_mean_time()
259259
writer.add_scalar('Time/comm time', mean_comm_time, train_cnt)
@@ -312,8 +312,6 @@ def main():
312312
parser.add_argument('--ma_method', choices=['p2p', 'broadcast', 'delicate', 'all_reduce'], type = str, default = 'p2p',
313313
help='Model average strategies')
314314

315-
parser.add_argument('--group', default="", type=str,
316-
help='Name of the group')
317315
parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N',
318316
help='number of data loading workers (default: 4)')
319317
parser.add_argument('-g', '--gpus', default=1, type=int,

0 commit comments

Comments
 (0)