in projects/dialogue_unlikelihood/agents.py [0:0]
def compute_loss(self, batch, return_output=False):
    """
    Compute the standard NLL loss plus a vocabulary unlikelihood penalty on tokens
    that the model generates more often than they appear in the human targets.
    """
    # reset the running token-count histories whenever we switch between train and eval
    if self._last_was_training is not self.is_training:
        self._reset_running_histories()
        self._last_was_training = self.is_training

    nll_loss, model_output = super().compute_loss(batch, True)
    scores, preds, *_ = model_output  # scores is bsz x time x vocab
    targets = batch.label_vec
    notnull = targets != self.NULL_IDX
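
    # Decode from the current model with beam search. No gradients are needed here:
    # the differentiable forward pass on these generations happens further below.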
    with torch.no_grad():
        beam_pred_scores, _ = self._generate(
            batch, self.beam_size, self.opt['label_truncate']
        )

        # forward pass to create graph for beam search case
        generations = [g for (g, s, _) in beam_pred_scores]
        gentoks = torch.nn.utils.rnn.pad_sequence(
            generations, batch_first=True, padding_value=self.NULL_IDX
        )
        # strip the BOS tokens
        gentoks = gentoks[:, 1:]
    # mask out the padding in the generations
    gen_mask = gentoks != self.NULL_IDX

    self.generation_history.append(Counter(gentoks[gen_mask].view(-1).tolist()))
    self.human_history.append(Counter(targets[notnull].view(-1).tolist()))
    self.running_generation += self.generation_history[-1]
    self.running_human += self.human_history[-1]
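
    # generation_history / human_history hold per-batch token Counters; the running_*
    # Counters aggregate them into token-frequency distributions for the model's
    # generations and the human targets respectively.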
    if len(self.generation_history) > self.NUM_STEPS:
        if not self.is_training:
            # we want a running history of word usage
            self.running_generation -= self.generation_history.pop(0)
            self.running_human -= self.human_history.pop(0)
    else:
        # not enough history accumulated yet: fall back to plain NLL for now
        if return_output:
            return nll_loss, model_output
        else:
            return nll_loss
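
    # Compare the running token frequencies of our generations against the human
    # targets and pick out the tokens we over-produce, assigning each a penalty weight:
    #   - 'logdiff': weight by the log of the (scaled) probability difference
    #   - 'uniform': weight every over-produced token equally
    #   - 'kldiv':   weight by the token's contribution to KL(generated || human)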
    gen_sum = sum(self.running_generation.values())
    hum_sum = sum(self.running_human.values())

    # what did we oversample?
    if self.opt['weighting'] == 'logdiff':
        to_penalize = {
            w: (v / gen_sum) - (self.running_human.get(w, 0) / hum_sum)
            for w, v in self.running_generation.items()
        }
        to_penalize = {
            w: v for w, v in to_penalize.items() if v >= self.opt['threshold']
        }
        to_penalize = {w: math.log(v / 0.001) for w, v in to_penalize.items()}
    elif self.opt['weighting'] == 'uniform':
        to_penalize = {
            w: (v / gen_sum) - (self.running_human.get(w, 0) / hum_sum)
            for w, v in self.running_generation.items()
        }
        to_penalize = {
            w: 1 for w, v in to_penalize.items() if v >= self.opt['threshold']
        }
    elif self.opt['weighting'] == 'kldiv':
        to_penalize = {
            w: (
                self.running_generation[w] / gen_sum,
                self.running_human[w] / hum_sum,
            )
            for w, v in self.running_human.items()
            if w in self.running_generation
        }
        to_penalize = {
            w: (p_gen, p_hum)
            for w, (p_gen, p_hum) in to_penalize.items()
            if p_gen > p_hum
        }
        to_penalize = {
            w: p_gen * (math.log(p_gen) - math.log(p_hum))
            for w, (p_gen, p_hum) in to_penalize.items()
        }
        to_penalize = {
            k: v for k, v in to_penalize.items() if v > self.opt['threshold']
        }
    else:
        raise ValueError(f"Unrecognized weighting scheme: {self.opt['weighting']}")
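
    # Build a mask over the generated token grid marking the penalized words, and a
    # parallel tensor holding each word's penalty weight.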
    self.global_metrics.add('num_penalize', SumMetric(len(to_penalize)))

    ul_weights = torch.zeros(gen_mask.shape)
    ul_mask = torch.zeros_like(gen_mask)
    for wordid, weight in to_penalize.items():
        ul_mask = ul_mask | (gentoks == wordid)
        ul_weights[gentoks == wordid] = weight
    ul_weights = ul_weights.to(gen_mask.device)
    self.global_metrics.add('ul_weights', AverageMetric(ul_weights[ul_mask].mean()))
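
    # Teacher-force the model on its own generations to get a differentiable
    # distribution, then penalize each flagged token w with
    #   -log(1 - p(w)) * weight   (the unlikelihood objective),
    # clamping the argument to avoid log(0).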
    # and whack it
    model_output = self.model(*self._model_input(batch), ys=gentoks)
    scores, *_ = model_output

    downweight = gentoks[ul_mask]
    almost_scores = F.log_softmax(scores[ul_mask], dim=-1)
    ul_scores = almost_scores[torch.arange(len(downweight)), downweight]

    clamp_min = 1e-6 if self.opt['fp16'] else 1e-20
    ul_loss = (
        -(torch.log(torch.clamp(1 - ul_scores.exp(), min=clamp_min)))
        * ul_weights[ul_mask]
    ).sum()
    num_ul = ul_mask.sum()

    self.global_metrics.add('ul_loss', AverageMetric(ul_loss, num_ul))
    self.global_metrics.add('ul_num_tokens', SumMetric(num_ul))

    # div is presumably a safe-division helper defined elsewhere in this module,
    # guarding against num_ul == 0
    ul_loss = div(ul_loss, num_ul)
    if len(self.generation_history) < self.NUM_STEPS:
        loss = nll_loss
    else:
        # total loss: standard NLL plus the weighted unlikelihood penalty
        loss = nll_loss + self.opt['alpha'] * ul_loss

    if return_output:
        return (loss, model_output)
    else:
        return loss