def sharded_topk()

in sparse_autoencoder/train.py [0:0]


import torch


def sharded_topk(x, k, sh_comm, capacity_factor=None):
    """Top-k over a feature dimension that is sharded across ranks.

    Each rank takes a local top-k of its shard of `x`, the local values are
    all-gathered, and a global top-k over the gathered values sets the
    threshold a local entry must meet to survive. Local entries below that
    threshold are zeroed out.
    """
    batch = x.shape[0]

    if capacity_factor is not None:
        # Cap the per-shard candidate count: each shard only needs to offer
        # roughly k * capacity_factor / num_shards candidates to the global
        # top-k (assumes sh_comm is not None whenever capacity_factor is set).
        k_in = min(int(k * capacity_factor // sh_comm.size()), k)
    else:
        k_in = k

    # Local top-k within this rank's shard of the feature dimension.
    topk = torch.topk(x, k=k_in, dim=-1)
    inds = topk.indices
    vals = topk.values

    # No sharding communicator: the local top-k is already the global top-k.
    if sh_comm is None:
        return inds, vals

    # Gather every shard's local top-k values into (num_shards, batch, k_in).
    all_vals = torch.empty(sh_comm.size(), batch, k_in, dtype=vals.dtype, device=vals.device)
    sh_comm.all_gather(all_vals, vals, async_op=True)

    all_vals = all_vals.permute(1, 0, 2)  # put shard dim next to k
    all_vals = all_vals.reshape(batch, -1)  # flatten shard into k

    # Global top-k over all shards' candidates; its smallest value is the
    # cutoff a local entry must meet to be kept.
    all_topk = torch.topk(all_vals, k=k, dim=-1)
    global_topk = all_topk.values

    # Zero out local entries that fall below the global k-th largest value;
    # surviving (index, value) pairs keep their original positions.
    dummy_vals = torch.zeros_like(vals)
    dummy_inds = torch.zeros_like(inds)

    my_inds = torch.where(vals >= global_topk[:, [-1]], inds, dummy_inds)
    my_vals = torch.where(vals >= global_topk[:, [-1]], vals, dummy_vals)

    return my_inds, my_vals
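
A minimal single-process sketch of how the function might be exercised. The shapes, the k value, and the _SingleShardComm stand-in below are hypothetical; they only illustrate the communicator interface sharded_topk relies on (size() and all_gather(output, input, async_op=...)), not the repository's actual communicator.

import torch


class _SingleShardComm:
    # Hypothetical one-shard stand-in exposing only what sharded_topk needs.
    def size(self):
        return 1

    def all_gather(self, out, x, async_op=False):
        # With a single shard, "gathering" is just a copy into slot 0.
        out[0].copy_(x)


x = torch.randn(4, 1024)  # (batch, features held by this shard)

# sh_comm=None path: reduces to a plain local torch.topk.
inds, vals = sharded_topk(x, k=32, sh_comm=None)
assert inds.shape == (4, 32) and vals.shape == (4, 32)

# Sharded path with one shard: the global threshold equals the local one,
# so every local entry survives and the result matches the local top-k.
inds2, vals2 = sharded_topk(x, k=32, sh_comm=_SingleShardComm())
assert torch.equal(inds2, inds) and torch.equal(vals2, vals)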