in sparse_autoencoder/train.py [0:0]
import torch


def sharded_topk(x, k, sh_comm, capacity_factor=None):
    """Top-k over activations whose latent dim is sharded across ranks;
    each rank keeps only the local winners that also make the global top-k."""
    batch = x.shape[0]
    if capacity_factor is not None and sh_comm is not None:
        # Over-provision by capacity_factor and split the budget evenly per shard.
        # (Guarding on sh_comm here avoids calling .size() when there is no comm.)
        k_in = min(int(k * capacity_factor // sh_comm.size()), k)
    else:
        k_in = k
    topk = torch.topk(x, k=k_in, dim=-1)
    inds = topk.indices
    vals = topk.values
    if sh_comm is None:
        # Single shard: the local top-k already is the global top-k.
        return inds, vals

    # Gather every shard's local top-k values so each rank can compute the
    # global k-th largest value, i.e. the admission threshold.
    all_vals = torch.empty(sh_comm.size(), batch, k_in, dtype=vals.dtype, device=vals.device)
    handle = sh_comm.all_gather(all_vals, vals, async_op=True)
    handle.wait()  # assumes a torch.distributed-style work handle; ensures all_vals is filled before use

    all_vals = all_vals.permute(1, 0, 2)  # put shard dim next to k
    all_vals = all_vals.reshape(batch, -1)  # flatten shard into k
    all_topk = torch.topk(all_vals, k=k, dim=-1)
    global_topk = all_topk.values

    # Zero out local entries that fall below the global cutoff global_topk[:, [-1]].
    dummy_vals = torch.zeros_like(vals)
    dummy_inds = torch.zeros_like(inds)
    my_inds = torch.where(vals >= global_topk[:, [-1]], inds, dummy_inds)
    my_vals = torch.where(vals >= global_topk[:, [-1]], vals, dummy_vals)
    return my_inds, my_vals
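A minimal usage sketch, assuming the single-process path (sh_comm=None); the shapes and the k value are illustrative, not taken from the training config.

# Hypothetical shapes: a batch of 8 rows with 512 latent pre-activations,
# keeping the top 32. With sh_comm=None the call reduces to a plain torch.topk.
x = torch.randn(8, 512)
inds, vals = sharded_topk(x, k=32, sh_comm=None)
assert inds.shape == (8, 32) and vals.shape == (8, 32)

# In the sharded case each rank would instead pass its local slice of the
# latent dimension plus its comm wrapper, e.g. (hypothetical `comm`):
#   my_inds, my_vals = sharded_topk(x_local_shard, k=32, sh_comm=comm)
# Entries that miss the global top-k come back zeroed on that rank.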