# low_rank_comparisons/src/gpu.py
def parse_gpu(args):
    """Initialize the distributed/GPU environment for the selected platform.

    Seeds torch, binds this process to its local CUDA device, and mutates
    ``args`` in place, attaching:
        - ``args.rank``: global rank of this process
        - ``args.local_rank``: rank on this node (used as CUDA device index)
        - ``args.device``: the ``torch.device("cuda", local_rank)`` handle
        - ``args.world_size``: total number of processes
        - ``args.dist`` (NCCL platforms) or ``args.hvd`` (Horovod/Azure):
          the communication module for later collective ops

    Supported ``args.platform`` values: 'local', 'azure', 'philly', 'k8s'.

    Raises:
        ValueError: if ``args.platform`` is not one of the supported values.
    """
    torch.manual_seed(args.random_seed)

    if args.platform == 'local':
        # Single-node launcher (e.g. torch.distributed.launch/torchrun):
        # rank comes from the freshly initialized process group.
        dist.init_process_group(backend='nccl')
        local_rank = dist.get_rank()
        torch.cuda.set_device(local_rank)
        # Fix: args.local_rank was never set on this branch, but the
        # summary print below reads it — every other branch sets it.
        args.local_rank = local_rank
        args.rank = local_rank
        args.device = torch.device("cuda", local_rank)
        args.world_size = dist.get_world_size()
        args.dist = dist
    elif args.platform == 'azure':
        # Azure uses Horovod instead of torch.distributed; imported lazily
        # so non-Azure runs don't require horovod to be installed.
        import horovod.torch as hvd
        hvd.init()
        print('azure hvd rank', hvd.rank(), 'local rank', hvd.local_rank())
        local_rank = hvd.local_rank()
        torch.cuda.set_device(local_rank)
        args.local_rank = local_rank
        args.rank = hvd.rank()
        args.device = torch.device("cuda", local_rank)
        args.world_size = hvd.size()
        args.hvd = hvd
    elif args.platform == 'philly':
        # Philly supplies the local rank on the command line (args.local_rank).
        local_rank = args.local_rank
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend='nccl')
        args.rank = dist.get_rank()
        args.device = torch.device("cuda", local_rank)
        args.world_size = dist.get_world_size()
        args.dist = dist
    elif args.platform == 'k8s':
        # Kubernetes + OpenMPI: topology comes from OMPI_* env vars, and the
        # rendezvous address from MASTER_ADDR/MASTER_PORT.
        master_uri = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        world_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        torch.cuda.set_device(local_rank)
        dist.init_process_group(
            backend='nccl',
            init_method=master_uri,
            world_size=world_size,
            rank=world_rank,
        )
        args.local_rank = local_rank
        args.rank = world_rank
        args.device = torch.device("cuda", local_rank)
        args.world_size = world_size
        args.dist = dist
    else:
        # Fix: previously an unknown platform fell through to the print below
        # and died with an opaque AttributeError on args.rank.
        raise ValueError('unknown platform: %s' % args.platform)

    print("myrank: ", args.rank, 'local_rank: ', args.local_rank, " device_count: ", torch.cuda.device_count(), "world_size:", args.world_size)