def preprocess_function()

in src/jobs/tune_t5.py [0:0]


    def preprocess_function(self, examples, teacher_model=None):
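        """Tokenize input_text into model inputs and build labels.

        With no teacher_model, labels are tokenized from examples["target_text"].
        With a teacher_model, labels are pseudo-labels decoded from the teacher's
        generations over the (unlabeled) inputs.
        """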
        inputs = examples["input_text"]
        model_inputs = self.tokenizer(
            inputs,
            max_length=512,
            truncation=True,
            padding="max_length"
        )
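        # Build labels: from target_text when no teacher is given, otherwise from teacher-generated pseudo-labels.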
        if teacher_model is None:
            targets = examples["target_text"]
            with self.tokenizer.as_target_tokenizer():
                targets_tokenized = self.tokenizer(
                    targets,
                    max_length=64,
                    truncation=True,
                    padding="max_length"
                )
            labels = targets_tokenized["input_ids"]
        else:
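            # Pseudo-labeling: the teacher generates target sequences for the (unlabeled) inputs.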
            teacher_inputs = self.tokenizer(
                inputs, return_tensors="pt", padding=True, truncation=True
            ).to(self.device)
            generated_ids = teacher_model.generate(
                teacher_inputs.input_ids,
                attention_mask=teacher_inputs.attention_mask,
                max_length=64
            ).to("cpu")
            decoded_outputs = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            print(f"Sample of decoded outputs from Unlabeled Dataset {decoded_outputs[:20]}")
            with self.tokenizer.as_target_tokenizer():
                tokenized_labels = self.tokenizer(
                    decoded_outputs,
                    max_length=64,
                    truncation=True,
                    padding="max_length"
                )
            labels = tokenized_labels["input_ids"]
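        # Note: pad token ids are kept in the labels here; they count toward the loss unless replaced with -100 downstream.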
        model_inputs["labels"] = labels

        return model_inputs
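
A minimal usage sketch, assuming the enclosing class in tune_t5.py (called T5Tuner here purely for illustration; the real class name, constructor, and teacher checkpoint are not shown in this snippet) exposes tokenizer, device, and preprocess_function, and that the data is a Hugging Face datasets.Dataset with the input_text / target_text columns referenced above:

    from datasets import Dataset
    from transformers import T5ForConditionalGeneration

    tuner = T5Tuner()  # hypothetical handle to the class that defines preprocess_function
    teacher = T5ForConditionalGeneration.from_pretrained("t5-small").to(tuner.device)

    # Supervised path: labels are tokenized from target_text.
    labeled = Dataset.from_dict({
        "input_text": ["summarize: the quick brown fox jumps over the lazy dog"],
        "target_text": ["a fox jumps over a dog"],
    })
    tokenized_labeled = labeled.map(tuner.preprocess_function, batched=True)

    # Pseudo-labeling path: the teacher generates labels for unlabeled text.
    unlabeled = Dataset.from_dict({"input_text": ["summarize: another unlabeled document"]})
    tokenized_unlabeled = unlabeled.map(
        lambda batch: tuner.preprocess_function(batch, teacher_model=teacher),
        batched=True,
    )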