in adaptive_io.py [0:0]
def forward(self, hidden, target):
    """
    Compute the per-token negative log-likelihood under the adaptive softmax.

    Input:
        - `hidden` FloatTensor(shape + (d_proj,))
        - `target` LongTensor(shape)
    Output:
        - `nll` FloatTensor(shape), always float32

    A target in ``[0, cutoffs[0])`` is scored by the head alone. A target in
    tail cluster ``i`` (``cutoffs[i-1] <= target < cutoffs[i]``) is scored as
    the head's log-probability of cluster ``i`` plus the tail's
    log-probability of the target within that cluster. Targets outside
    ``[0, cutoffs[-1])`` fall in no cluster and keep an nll of 0.
    """
    assert hidden.shape[-1] == self.d_proj, "hidden last dim must equal d_proj"
    assert hidden.shape[:-1] == target.shape, "hidden and target leading shapes differ"
    shape = target.shape
    # reshape (not view) so non-contiguous inputs, e.g. transposed
    # activations, are accepted; contiguous inputs are still not copied.
    hidden = hidden.reshape(-1, self.d_proj)
    target = target.reshape(-1)

    # Build per-cluster output weights. The head (i == 0) scores the shortlist
    # plus the cluster-assignment logits: the cluster projection rows are
    # appended after the first output layer's rows.
    weights, biases = [], []
    for i in range(len(self.cutoffs)):
        weight_i = self.out_layers[i].weight
        bias_i = self.out_layers[i].bias
        if i == 0:
            weight_i = torch.cat([weight_i, self.cluster_proj.weight], dim=0)
            bias_i = torch.cat([bias_i, self.cluster_proj.bias], dim=0)
        weights.append(weight_i)
        biases.append(bias_i)

    # Head / cluster-assignment log-probabilities. The logits are cast to
    # float32 before log_softmax (presumably so reduced-precision logits are
    # normalized in full precision).
    head_logit = self._compute_logit(hidden, weights[0], biases[0], self.out_projs[0])
    head_logprob = F.log_softmax(head_logit.float(), dim=1)

    # Final per-token nll; entries for targets in no cluster stay 0.
    nll = torch.zeros_like(target, dtype=torch.float32, device=hidden.device)
    cutoff_values = [0] + self.cutoffs
    for i in range(len(cutoff_values) - 1):
        # Select the target tokens that fall in cluster i.
        l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]
        mask_i = (target >= l_idx) & (target < r_idx)
        # squeeze(1), not squeeze(): keep a 1-D index tensor even when exactly
        # one token falls in this cluster.
        indices_i = mask_i.nonzero().squeeze(1)
        if indices_i.numel() == 0:
            continue

        # Target index relative to the start of the cluster.
        target_i = target.index_select(0, indices_i) - l_idx
        head_logprob_i = head_logprob.index_select(0, indices_i)
        if i == 0:
            # Head-cluster targets are scored by the head distribution alone.
            logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
        else:
            # Tail targets: cluster-assignment score from the head plus the
            # target's score within the tail distribution.
            hidden_i = hidden.index_select(0, indices_i)
            tail_logit_i = self._compute_logit(hidden_i, weights[i], biases[i], self.out_projs[i])
            tail_logprob_i = F.log_softmax(tail_logit_i.float(), dim=1)
            # Cluster i's assignment logit is read from head column -i, i.e. the
            # appended cluster columns are matched in reverse order (assuming
            # cluster_proj has one output per tail cluster — confirm). Kept
            # as-is: changing it would silently break trained weights and any
            # other code relying on the same convention.
            logprob_i = head_logprob_i[:, -i] + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)

        nll.index_copy_(0, indices_i, -logprob_i)
    return nll.view(shape)