kilt/readers/t5/finetune.py [101:124]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        pad_token_id = self.tokenizer.pad_token_id

        source_ids, source_mask, y = KiltDataset.trim_seq2seq_batch(batch, pad_token_id)
        generated_ids = self.model.generate(
            input_ids=source_ids,
            attention_mask=source_mask,
            num_beams=1,
            max_length=self.target_length,
            repetition_penalty=1.0,  # 1.0 = no repetition penalty
            length_penalty=1.0,
            early_stopping=True,  # no effect here: early stopping only applies when num_beams > 1
            use_cache=True,
            do_sample=False,  # greedy decoding, so top_p and top_k below are ignored
            top_p=0.95,
            top_k=50,
            bad_words_ids=self.bad_words
        )

        # skip_special_tokens keeps <pad>/</s> markers out of the decoded strings
        preds = [self.tokenizer.decode(g, skip_special_tokens=True) for g in generated_ids]
        target = [self.tokenizer.decode(t, skip_special_tokens=True) for t in y]
        loss = self._step(batch)
        sources = [self.tokenizer.decode(s, skip_special_tokens=True) for s in source_ids]

        return {"val_loss": loss, "sources": sources, "preds": preds, "target": target, "ids": batch["ids"]}
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



kilt/readers/t5/finetune.py [151:174]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        pad_token_id = self.tokenizer.pad_token_id

        source_ids, source_mask, y = KiltDataset.trim_seq2seq_batch(batch, pad_token_id)
        # NOTE: these kwargs favor decoding speed; evaluate_kilt_task.py uses slower,
        # higher-quality generation settings.
        generated_ids = self.model.generate(
            input_ids=source_ids,
            attention_mask=source_mask,
            num_beams=1,
            max_length=self.target_length,
            repetition_penalty=1.0,  # 1.0 = no repetition penalty
            length_penalty=1.0,
            early_stopping=True,  # no effect here: early stopping only applies when num_beams > 1
            use_cache=True,
            do_sample=False,  # greedy decoding, so top_p and top_k below are ignored
            top_p=0.95,
            top_k=50,
            bad_words_ids=self.bad_words
        )
        # skip_special_tokens keeps <pad>/</s> markers out of the decoded strings
        preds = [self.tokenizer.decode(g, skip_special_tokens=True) for g in generated_ids]
        target = [self.tokenizer.decode(t, skip_special_tokens=True) for t in y]
        loss = self._step(batch)
        sources = [self.tokenizer.decode(s, skip_special_tokens=True) for s in source_ids]
        return {"val_loss": loss, "sources": sources, "preds": preds, "target": target, "ids": batch["ids"]}
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
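
The NOTE above flags a speed/quality tradeoff, but the higher-quality settings are not reproduced in this excerpt. For comparison, below is a minimal, self-contained sketch of a quality-oriented generate() call using beam search. The checkpoint name, beam count, and repetition penalty are illustrative assumptions, not the values used in evaluate_kilt_task.py.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# Illustrative sketch only: these settings are assumptions, NOT taken from
# evaluate_kilt_task.py, whose exact kwargs are not shown in this report.
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")  # hypothetical checkpoint
model = T5ForConditionalGeneration.from_pretrained("t5-small")

batch = tokenizer(["summarize: The quick brown fox jumps over the lazy dog."],
                  return_tensors="pt", padding=True)
generated_ids = model.generate(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    num_beams=4,             # beam search: slower but usually higher quality
    max_length=64,
    repetition_penalty=1.2,  # > 1.0 actively discourages repeated tokens
    length_penalty=1.0,
    early_stopping=True,     # meaningful here because num_beams > 1
    do_sample=False,         # deterministic search; top_p/top_k omitted
)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -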




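Since the two excerpts are identical apart from the NOTE comment, the duplication could be removed by factoring the shared body into one helper that both steps call. A minimal sketch follows, assuming the class is a PyTorch Lightning module with validation_step/test_step hooks; the helper name _generative_step and the exact return shape are illustrative, not the repo's actual structure.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    def _generative_step(self, batch):
        # Hypothetical shared helper (name and placement are assumptions).
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, y = KiltDataset.trim_seq2seq_batch(batch, pad_token_id)
        generated_ids = self.model.generate(
            input_ids=source_ids,
            attention_mask=source_mask,
            num_beams=1,
            max_length=self.target_length,
            length_penalty=1.0,
            use_cache=True,
            do_sample=False,
            bad_words_ids=self.bad_words,
        )

        def decode(seqs):
            return [self.tokenizer.decode(s, skip_special_tokens=True) for s in seqs]

        return {"val_loss": self._step(batch), "sources": decode(source_ids),
                "preds": decode(generated_ids), "target": decode(y), "ids": batch["ids"]}

    def validation_step(self, batch, batch_idx):
        return self._generative_step(batch)

    def test_step(self, batch, batch_idx):
        return self._generative_step(batch)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -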