in adaptive_io.py [0:0]
def forward(self, indices):
    """Embed `indices` with the adaptive (clustered) embedding tables.

    Each vocabulary cluster `i` (delimited by ``self.cutoff_ends``) has its
    own embedding table ``self.emb_layers[i]`` and projection
    ``self.emb_projs[i]`` mapping that cluster's embedding size up to the
    shared dimension ``self.d_proj``.

    Args:
        indices: integer tensor of token ids, any shape.

    Returns:
        Float tensor of shape ``(*indices.shape, self.d_proj)``, scaled by
        ``self.emb_scale``.
    """
    # Borrow dtype/device from the first table so the output buffer matches.
    param = self.emb_layers[0].weight.data
    idx_flat = indices.contiguous().view(-1)
    emb_flat = torch.zeros([idx_flat.size(0), self.d_proj], dtype=param.dtype, device=param.device)
    # for each cluster
    for i in range(len(self.cutoffs)):
        # find elements in that cluster
        l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
        mask_i = (idx_flat >= l_idx) & (idx_flat < r_idx)
        # view(-1) (NOT squeeze()) keeps the index 1-D even when the cluster
        # holds exactly one token; a bare squeeze() would produce a 0-dim
        # tensor and make index_select/index_copy_ below fail.
        indices_i = mask_i.nonzero().view(-1)
        # if there are no elements, continue
        if indices_i.numel() == 0:
            continue
        # add embeddings from this cluster (ids are local to the cluster)
        idx_i = idx_flat.index_select(0, indices_i) - l_idx
        emb_i = self.emb_layers[i](idx_i)
        emb_i = F.linear(emb_i, self.emb_projs[i])
        # small hack for AMP-O1: autocast may hand back a lower-precision
        # emb_i, so retype the accumulator before index_copy_
        emb_flat = emb_flat.type_as(emb_i) if emb_flat.dtype != emb_i.dtype else emb_flat
        emb_flat.index_copy_(0, indices_i, emb_i)
    # reshape embeddings back to the input's shape (+ feature dim)
    embed = emb_flat.view(*indices.size(), self.d_proj)
    # rescale embeddings
    embed.mul_(self.emb_scale)
    return embed