def dist_model()

in distributed_training/src_dir/dis_util.py [0:0]


def dist_model(model, args):
    """Move the model to GPU and wrap it for the configured parallelism mode."""
    if args.multigpus_distributed:
        # Optional Apex synchronized BatchNorm conversion (currently disabled):
        # if args.sync_bn:
        #     import apex
        #     print("using apex synced BN")
        #     model = apex.parallel.convert_syncbn_model(model)

        if args.local_rank is not None:
            # Pin this process to its local GPU before wrapping the model.
            torch.cuda.set_device(args.local_rank)

            if not (args.apex or args.data_parallel or args.model_parallel):
                model.cuda(args.local_rank)
                # DDP should be given the local device index, not the global rank.
                model = torch.nn.parallel.DistributedDataParallel(
                    model, device_ids=[args.local_rank])
        else:
            # No local rank given: DDP uses all GPUs visible to this process.
            if not (args.apex or args.data_parallel or args.model_parallel):
                model.cuda()
                model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.rank is not None:
        # Non-distributed run pinned to a single GPU.
        torch.cuda.set_device(args.rank)
        if not (args.apex or args.data_parallel or args.model_parallel):
            model = model.cuda(args.rank)
    else:
        # Fall back to single-node DataParallel across all visible GPUs.
        if not (args.apex or args.data_parallel or args.model_parallel):
            model = torch.nn.DataParallel(model).cuda()

    return model, args
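
A minimal usage sketch, assuming the module is imported as dis_util and that args carries the flags the function reads (multigpus_distributed, local_rank, rank, apex, data_parallel, model_parallel, sync_bn); the process-group setup and launch via torchrun shown here are illustrative assumptions, not part of dis_util.py.

import os
from types import SimpleNamespace

import torch
import torch.distributed as dist

from dis_util import dist_model  # assumes src_dir is on PYTHONPATH


def main():
    # Flag names mirror what dist_model reads; the values here are illustrative.
    args = SimpleNamespace(
        multigpus_distributed=True,
        local_rank=int(os.environ.get("LOCAL_RANK", 0)),
        rank=int(os.environ.get("RANK", 0)),
        apex=False,
        data_parallel=False,
        model_parallel=False,
        sync_bn=False,
    )

    # DistributedDataParallel requires an initialized process group,
    # e.g. when launched with torchrun, which sets the required env vars.
    dist.init_process_group(backend="nccl")

    model = torch.nn.Linear(128, 10)
    model, args = dist_model(model, args)
    # ... build the optimizer and run the training loop on the wrapped model ...


if __name__ == "__main__":
    main()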