# Excerpt: forward() from adaptive_io.py

    def forward(self, indices):
        """Look up adaptive embeddings for a tensor of token indices.

        Tokens are partitioned into clusters by ``self.cutoffs``; each cluster
        has its own (typically smaller) embedding table ``self.emb_layers[i]``
        that is projected up to the shared dimension ``self.d_proj`` via
        ``self.emb_projs[i]``.

        Args:
            indices: integer tensor of arbitrary shape holding token ids in
                ``[0, cutoff_ends[-1])``.

        Returns:
            Float tensor of shape ``(*indices.shape, self.d_proj)``, scaled by
            ``self.emb_scale``.
        """
        param = self.emb_layers[0].weight.data
        idx_flat = indices.contiguous().view(-1)
        emb_flat = torch.zeros([idx_flat.size(0), self.d_proj], dtype=param.dtype, device=param.device)

        # for each cluster
        for i in range(len(self.cutoffs)):
            # find elements in that cluster
            l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
            mask_i = (idx_flat >= l_idx) & (idx_flat < r_idx)

            # squeeze(-1) (not squeeze()) so a single match stays a 1-D tensor;
            # a bare squeeze() would yield a 0-dim tensor and break
            # index_select / index_copy_, which require 1-D indices.
            indices_i = mask_i.nonzero().squeeze(-1)

            # if there are no elements, continue
            if indices_i.numel() == 0:
                continue

            # add embeddings from this cluster; ids are shifted to be local
            # to this cluster's table before lookup
            idx_i = idx_flat.index_select(0, indices_i) - l_idx
            emb_i = self.emb_layers[i](idx_i)
            emb_i = F.linear(emb_i, self.emb_projs[i])
            emb_flat = emb_flat.type_as(emb_i) if emb_flat.dtype != emb_i.dtype else emb_flat  # small hack for AMP-O1
            emb_flat.index_copy_(0, indices_i, emb_i)

        # reshape embeddings back to the input shape plus the embedding dim
        embed = emb_flat.view(*indices.size(), self.d_proj)

        # rescale embeddings (in place; emb_flat is not reused afterwards)
        embed.mul_(self.emb_scale)

        return embed