# parse_gpu() — distributed/GPU runtime setup
# (from low_rank_comparisons/src/gpu.py)


def parse_gpu(args):
    """Initialize the distributed/CUDA runtime for the selected platform.

    Mutates *args* in place, setting: ``rank``, ``local_rank``, ``device``
    (a ``torch.device("cuda", local_rank)``), ``world_size``, and either
    ``dist`` (torch.distributed) or ``hvd`` (horovod), depending on platform.

    Args:
        args: argparse-style namespace with ``platform`` (one of 'local',
            'azure', 'philly', 'k8s') and ``random_seed``; for 'philly',
            ``local_rank`` must already be set (passed on the command line).

    Raises:
        ValueError: if ``args.platform`` is not a recognized platform.
    """
    torch.manual_seed(args.random_seed)

    if args.platform == 'local':
        # Single-node launch: torch.distributed assigns ranks via its
        # default (env://) rendezvous.
        dist.init_process_group(backend='nccl')
        local_rank = dist.get_rank()
        torch.cuda.set_device(local_rank)
        args.local_rank = local_rank  # was missing: final print below reads it
        args.rank = local_rank
        args.device = torch.device("cuda", local_rank)
        args.world_size = dist.get_world_size()
        args.dist = dist

    elif args.platform == 'azure':
        # Azure uses Horovod; imported lazily so the other platforms do not
        # require horovod to be installed.
        import horovod.torch as hvd
        hvd.init()
        print('azure hvd rank', hvd.rank(), 'local rank', hvd.local_rank())
        local_rank = hvd.local_rank()
        torch.cuda.set_device(local_rank)
        args.local_rank = local_rank
        args.rank = hvd.rank()
        args.device = torch.device("cuda", local_rank)
        args.world_size = hvd.size()
        args.hvd = hvd

    elif args.platform == 'philly':
        # Philly supplies the local rank via the launcher (args.local_rank).
        local_rank = args.local_rank
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend='nccl')
        args.rank = dist.get_rank()
        args.device = torch.device("cuda", local_rank)
        args.world_size = dist.get_world_size()
        args.dist = dist

    elif args.platform == 'k8s':
        # Kubernetes + OpenMPI: ranks come from OMPI_* env vars; the process
        # group rendezvous happens over TCP at MASTER_ADDR:MASTER_PORT.
        master_uri = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        world_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        torch.cuda.set_device(local_rank)
        dist.init_process_group(
            backend='nccl',
            init_method=master_uri,
            world_size=world_size,
            rank=world_rank,
        )
        args.local_rank = local_rank
        args.rank = world_rank
        args.device = torch.device("cuda", local_rank)
        args.world_size = world_size
        args.dist = dist

    else:
        # Fail fast with a clear message instead of an AttributeError
        # on args.rank in the print below.
        raise ValueError("unknown platform: %r" % (args.platform,))

    print("myrank: ", args.rank, 'local_rank: ', args.local_rank,
          " device_count: ", torch.cuda.device_count(),
          "world_size:", args.world_size)