in distributed_training/src_dir/dis_util.py [0:0]
def dist_model(model, args):
    """Place ``model`` on GPU(s) and wrap it for the configured parallel mode.

    The path is selected from flags on ``args``:

    * ``args.multigpus_distributed`` and ``args.local_rank`` set: make
      ``local_rank`` the current CUDA device, move the model there and wrap it
      in ``DistributedDataParallel`` with ``device_ids=[local_rank]``.
    * ``args.multigpus_distributed`` and no ``local_rank``: move the model to
      the current CUDA device and wrap it in ``DistributedDataParallel``
      without explicit ``device_ids``.
    * not distributed and ``args.rank`` set: make ``rank`` the current CUDA
      device and move the model there (no wrapper).
    * otherwise: wrap the model in ``torch.nn.DataParallel`` and move it to
      CUDA.

    If any of ``args.apex``, ``args.data_parallel`` or ``args.model_parallel``
    is truthy, the model is returned unmoved and unwrapped. The
    ``torch.cuda.set_device`` calls still happen in that case. Presumably
    those modes place and wrap the model elsewhere — TODO confirm with callers.

    The ``DistributedDataParallel`` paths assume the caller has already
    initialized the ``torch.distributed`` process group.

    Args:
        model: module to place; must support ``.cuda()`` (e.g. a
            ``torch.nn.Module``).
        args: namespace carrying ``multigpus_distributed``, ``local_rank``,
            ``rank``, ``apex``, ``data_parallel`` and ``model_parallel``.

    Returns:
        ``(model, args)``: the possibly moved/wrapped model and ``args``,
        which this function does not modify.
    """
    # Only the native PyTorch path moves/wraps the model in this function.
    native = not (args.apex or args.data_parallel or args.model_parallel)

    if args.multigpus_distributed:
        # NOTE: an apex SyncBatchNorm conversion used to sit here (disabled).
        if args.local_rank is not None:
            torch.cuda.set_device(args.local_rank)
            if native:
                model = model.cuda(args.local_rank)
                # device_ids must name the GPU the module was moved to.
                # args.rank (presumably the global rank) differs from
                # local_rank on multi-node runs and can name a missing device.
                model = torch.nn.parallel.DistributedDataParallel(
                    model, device_ids=[args.local_rank])
        elif native:
            model = model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.rank is not None:
        torch.cuda.set_device(args.rank)
        if native:
            model = model.cuda(args.rank)
    elif native:
        model = torch.nn.DataParallel(model).cuda()
    return model, args