in sparse_autoencoder/train.py [0:0]
def make_torch_comms(n_op_shards=4, n_replicas=2):
    # Single-process fallback: if RANK is not set (i.e. not launched via torchrun),
    # no process groups are created and the trivial comms object is returned.
    if "RANK" not in os.environ:
        assert n_op_shards == 1
        assert n_replicas == 1
        return TRIVIAL_COMMS

    rank = int(os.environ["RANK"])
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    # Pin each rank to one of the node's 8 GPUs.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(rank % 8)

    print(f"{rank=}, {world_size=}")
    dist.init_process_group("nccl")

    # Ranks are laid out op-shard-major: consecutive ranks form the shards of one replica.
    my_op_shard_idx = rank % n_op_shards
    my_replica_idx = rank // n_op_shards

    # One shard group per replica: the n_op_shards consecutive ranks that together
    # hold a single copy of the model.
    shard_rank_lists = [
        list(range(i, i + n_op_shards)) for i in range(0, world_size, n_op_shards)
    ]
    shard_groups = [dist.new_group(shard_rank_list) for shard_rank_list in shard_rank_lists]
    my_shard_group = shard_groups[my_replica_idx]

    # One data-parallel (replica) group per op shard: the ranks holding the same
    # shard index across all replicas.
    replica_rank_lists = [
        list(range(i, n_op_shards * n_replicas, n_op_shards)) for i in range(n_op_shards)
    ]
    replica_groups = [dist.new_group(replica_rank_list) for replica_rank_list in replica_rank_lists]
    my_replica_group = replica_groups[my_op_shard_idx]

    # Warm up NCCL with a throwaway all-reduce before handing the groups out.
    torch.distributed.all_reduce(torch.ones(1).cuda())
    torch.cuda.synchronize()

    dp_comm = Comm(group=my_replica_group)
    sh_comm = Comm(group=my_shard_group)

    return ShardingComms(
        n_replicas=n_replicas,
        n_op_shards=n_op_shards,
        dp_comm=dp_comm,
        sh_comm=sh_comm,
        dp_rank=my_replica_idx,
        sh_rank=my_op_shard_idx,
        _rank=rank,
    )
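
# Usage sketch (illustrative; not part of the original file). It assumes the
# module-level imports of train.py (os, torch, torch.distributed as dist) and the
# Comm / ShardingComms / TRIVIAL_COMMS helpers defined elsewhere in the module.
# The group construction above only covers every rank when
# world_size == n_op_shards * n_replicas, so the default of 4 op shards x 2 replicas
# matches an 8-process launch; without RANK in the environment, both arguments
# must be 1.
#
#   comms = make_torch_comms(n_op_shards=4, n_replicas=2)  # e.g. torchrun --nproc-per-node 8 train.py
#   comms = make_torch_comms(n_op_shards=1, n_replicas=1)  # single-process debug run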