def compute_loss()

in projects/dialogue_unlikelihood/agents.py [0:0]


    def compute_loss(self, batch, return_output=False):
        """
        Compute the NLL loss plus a vocabulary-level unlikelihood penalty.

        The model's beam-search generations for ``batch`` and the human label
        tokens are tallied into running token-frequency histories. Once more than
        ``self.NUM_STEPS`` batches have been recorded, tokens the model produces
        more often than humans do (scored according to ``opt['weighting']``) are
        penalized with a weighted ``-log(1 - p(token))`` term, computed on a
        second forward pass over the generations. That term is averaged over the
        penalized positions, scaled by ``opt['alpha']`` and added to the NLL loss.

        :param batch:
            the batch to compute the loss on; ``batch.label_vec`` holds targets.
        :param return_output:
            if True, also return the model output.

        :return:
            ``loss`` or ``(loss, model_output)``. While the history is still
            short this is the plain NLL loss and the output of the NLL forward
            pass; afterwards it is ``nll + alpha * ul`` and the output of the
            forward pass over the generated tokens.

        :raises ValueError:
            if ``opt['weighting']`` is not 'logdiff', 'uniform' or 'kldiv'
            (only checked once the history is long enough to apply the penalty).
        """
        # Start fresh token histories whenever we switch between training and
        # evaluation.
        if self._last_was_training is not self.is_training:
            self._reset_running_histories()
            self._last_was_training = self.is_training

        nll_loss, model_output = super().compute_loss(batch, True)
        targets = batch.label_vec
        notnull = targets != self.NULL_IDX

        with torch.no_grad():
            beam_pred_scores, _ = self._generate(
                batch, self.beam_size, self.opt['label_truncate']
            )

            # forward pass to create graph for beam search case
            generations = [g for (g, s, _) in beam_pred_scores]
            gentoks = torch.nn.utils.rnn.pad_sequence(
                generations, batch_first=True, padding_value=self.NULL_IDX
            )
            # strip the BOS tokens
            gentoks = gentoks[:, 1:]

        # Record this batch's token usage: model generations vs. human labels.
        gen_mask = gentoks != self.NULL_IDX
        self.generation_history.append(Counter(gentoks[gen_mask].view(-1).tolist()))
        self.human_history.append(Counter(targets[notnull].view(-1).tolist()))
        self.running_generation += self.generation_history[-1]
        self.running_human += self.human_history[-1]

        if len(self.generation_history) <= self.NUM_STEPS:
            # Not enough history yet to apply the penalty: NLL only.
            if return_output:
                return nll_loss, model_output
            return nll_loss

        # NOTE(review): old entries are only dropped outside training, so while
        # training the history lists and running counts keep growing until
        # `_reset_running_histories()` runs on the next mode switch — confirm
        # this is intended.
        if not self.is_training:
            # we want a running history of word usage: keep a sliding window of
            # the last NUM_STEPS batches
            self.running_generation -= self.generation_history.pop(0)
            self.running_human -= self.human_history.pop(0)

        weighting = self.opt['weighting']
        if weighting not in ('logdiff', 'uniform', 'kldiv'):
            raise ValueError(f"Unknown weighting: {weighting!r}")
        threshold = self.opt['threshold']

        gen_sum = sum(self.running_generation.values())
        hum_sum = sum(self.running_human.values())

        # what did we oversample? Map token id -> penalty weight.
        if weighting in ('logdiff', 'uniform'):
            # excess relative frequency of each generated token over human usage
            oversampled = {
                w: (v / gen_sum) - (self.running_human.get(w, 0) / hum_sum)
                for w, v in self.running_generation.items()
            }
            oversampled = {w: v for w, v in oversampled.items() if v >= threshold}
            if weighting == 'logdiff':
                to_penalize = {w: math.log(v / 0.001) for w, v in oversampled.items()}
            else:
                to_penalize = {w: 1 for w in oversampled}
        else:  # 'kldiv'
            # per-token KL contribution p_gen * log(p_gen / p_hum), restricted
            # to tokens both sides used and that the model over-produces
            to_penalize = {}
            for w, hum_count in self.running_human.items():
                if w not in self.running_generation:
                    continue
                p_gen = self.running_generation[w] / gen_sum
                p_hum = hum_count / hum_sum
                if p_gen <= p_hum:
                    continue
                kl_term = p_gen * (math.log(p_gen) - math.log(p_hum))
                if kl_term > threshold:
                    to_penalize[w] = kl_term

        self.global_metrics.add('num_penalize', SumMetric(len(to_penalize)))

        # Mask of generated positions holding a penalized token, plus the
        # per-position weight. Allocated on the generations' device so the
        # boolean-mask assignment never mixes CPU and GPU tensors.
        ul_weights = torch.zeros(gen_mask.shape, device=gen_mask.device)
        ul_mask = torch.zeros_like(gen_mask)
        for wordid, weight in to_penalize.items():
            is_word = gentoks == wordid
            ul_mask = ul_mask | is_word
            ul_weights[is_word] = weight

        if ul_mask.any():
            # the mean of an empty selection is NaN; only report when non-empty
            self.global_metrics.add(
                'ul_weights', AverageMetric(ul_weights[ul_mask].mean())
            )

        # and whack it: re-score the generations outside torch.no_grad() so the
        # unlikelihood term is differentiable
        model_output = self.model(*self._model_input(batch), ys=gentoks)
        scores, *_ = model_output
        downweight = gentoks[ul_mask]

        # log-probability the model assigns to each penalized generated token
        log_probs = F.log_softmax(scores[ul_mask], dim=-1)
        ul_scores = log_probs[
            torch.arange(len(downweight), device=downweight.device), downweight
        ]

        # keep 1 - p away from zero so the log stays finite (looser under fp16)
        clamp_min = 1e-6 if self.opt['fp16'] else 1e-20

        ul_loss = (
            -(torch.log(torch.clamp(1 - ul_scores.exp(), min=clamp_min)))
            * ul_weights[ul_mask]
        ).sum()
        num_ul = ul_mask.sum()

        self.global_metrics.add('ul_loss', AverageMetric(ul_loss, num_ul))
        self.global_metrics.add('ul_num_tokens', SumMetric(num_ul))

        ul_loss = div(ul_loss, num_ul)

        # The history always holds at least NUM_STEPS batches at this point
        # (shorter histories returned early), so the penalty is always applied.
        loss = nll_loss + self.opt['alpha'] * ul_loss

        if return_output:
            return (loss, model_output)
        return loss