def make_torch_comms()

in sparse_autoencoder/train.py [0:0]

Sets up torch.distributed for an n_op_shards * n_replicas grid of ranks: it joins the default NCCL process group, builds one process group per op shard and one per replica, and returns a ShardingComms bundling the data-parallel comm (dp_comm) and the shard comm (sh_comm) together with this rank's indices. When RANK is not set in the environment it falls back to the single-process TRIVIAL_COMMS. The Comm, ShardingComms, and TRIVIAL_COMMS helpers are defined elsewhere in the codebase.


def make_torch_comms(n_op_shards=4, n_replicas=2):
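    # No RANK in the environment means no distributed launcher: sharding and
    # replication must both be 1 and the trivial single-process comms object is returned.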
    if "RANK" not in os.environ:
        assert n_op_shards == 1
        assert n_replicas == 1
        return TRIVIAL_COMMS

    rank = int(os.environ.get("RANK"))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
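    # Pin this process to a single local GPU; rank % 8 assumes 8 GPUs per node.
    # (CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been initialized yet.)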
    os.environ["CUDA_VISIBLE_DEVICES"] = str(rank % 8)

    print(f"{rank=}, {world_size=}")
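    # Join the default process group over NCCL; with no init_method given, rendezvous
    # uses the standard env vars (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE).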
    dist.init_process_group("nccl")

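    # Ranks are laid out shard-major: rank = replica_idx * n_op_shards + op_shard_idx,
    # so world_size is expected to equal n_op_shards * n_replicas.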
    my_op_shard_idx = rank % n_op_shards
    my_replica_idx = rank // n_op_shards

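    # One shard group per replica: the n_op_shards consecutive ranks whose shards
    # together make up that replica.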
    shard_rank_lists = [list(range(i, i + n_op_shards)) for i in range(0, world_size, n_op_shards)]

    shard_groups = [dist.new_group(shard_rank_list) for shard_rank_list in shard_rank_lists]

    my_shard_group = shard_groups[my_replica_idx]

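    # One replica group per op shard: the ranks holding the same shard index across
    # all replicas (stride n_op_shards), used for data-parallel reduction.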
    replica_rank_lists = [
        list(range(i, n_op_shards * n_replicas, n_op_shards)) for i in range(n_op_shards)
    ]

    replica_groups = [dist.new_group(replica_rank_list) for replica_rank_list in replica_rank_lists]

    my_replica_group = replica_groups[my_op_shard_idx]

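    # Dummy all-reduce on the default group as an NCCL warm-up / sanity check.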
    dist.all_reduce(torch.ones(1).cuda())
    torch.cuda.synchronize()

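    # dp_comm reduces across replicas (same op shard, different replicas);
    # sh_comm reduces across op shards within a single replica.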
    dp_comm = Comm(group=my_replica_group)
    sh_comm = Comm(group=my_shard_group)

    return ShardingComms(
        n_replicas=n_replicas,
        n_op_shards=n_op_shards,
        dp_comm=dp_comm,
        sh_comm=sh_comm,
        dp_rank=my_replica_idx,
        sh_rank=my_op_shard_idx,
        _rank=rank,
    )
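
For the default arguments (n_op_shards=4, n_replicas=2) the function expects world_size = 8 and produces the group layout sketched below. The launch line is an assumption based only on the RANK / WORLD_SIZE environment variables the function reads (torchrun sets both); neither the launcher nor the -m entry point is prescribed by this file.

# Hypothetical launch on one 8-GPU node (assumption: any launcher that sets
# RANK / WORLD_SIZE / MASTER_ADDR works, e.g. torchrun):
#
#   torchrun --nproc_per_node=8 -m sparse_autoencoder.train
#
# Each process then calls:
comms = make_torch_comms(n_op_shards=4, n_replicas=2)

# which builds:
#   shard_rank_lists   == [[0, 1, 2, 3], [4, 5, 6, 7]]      # one shard group per replica
#   replica_rank_lists == [[0, 4], [1, 5], [2, 6], [3, 7]]  # one replica group per op shard
#
# Rank 6, for example, gets sh_rank = 6 % 4 = 2 and dp_rank = 6 // 4 = 1, so its
# sh_comm spans ranks [4, 5, 6, 7] and its dp_comm spans ranks [2, 6].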