def predict()

in notebooks/packed_bert/pipeline/packed_bert.py [0:0]


    def predict(self, sentence_1, sentence_2=None):
        """Run packed inference on one or two text columns and return predictions.

        Args:
            sentence_1: Texts for the ``"text"`` dataset column. Presumably a list
                of str: it is passed as a column to ``Dataset.from_dict`` and its
                ``len`` is used as the example count for throughput. Confirm with
                callers.
            sentence_2: Optional texts for the ``"text_2"`` column, used for
                sentence-pair inputs. The column is only added when this value
                is truthy. The value is also stored on ``self.sentence_2_key``.

        Returns:
            dict with keys:
                ``"predictions"``: dict mapping example index to a dict of
                    per-output values. The inner keys are the names in
                    ``self.label_categories`` when that sequence is set and its
                    length equals the last dimension of the post-processed
                    predictions. Otherwise the inner keys are integer output
                    indices.
                ``"throughput"``: ``len(sentence_1)`` divided by the model-loop
                    time.
                ``"inference_total_time"``: seconds spent in the model loop.
                ``"preprocessing_time"``: seconds for tokenization, packing and
                    dataloader construction.
                ``"postprocessing_time"``: seconds for reordering and labelling.
        """
        # Presumably read by self.preprocess_function to decide whether a second
        # text column must be tokenized -- confirm against that method.
        self.sentence_2_key = sentence_2

        # perf_counter is monotonic and high-resolution, so the measured
        # intervals cannot be skewed by system clock adjustments as time.time()
        # can be.
        prep_st = time.perf_counter()

        data_dict = {"text": sentence_1}
        if sentence_2:
            data_dict["text_2"] = sentence_2

        dataset = Dataset.from_dict(data_dict)
        enc_data = dataset.map(self.preprocess_function, batched=True)

        # Pack several tokenized sequences into each row (at most
        # max_seq_per_pack per row). pad_to_global_batch_size presumably pads
        # the packed set so that it fills whole global batches -- confirm in
        # PackedDatasetCreator.
        packed_data = PackedDatasetCreator(
            tokenized_dataset=enc_data,
            max_sequence_length=self.max_seq_length,
            max_sequences_per_pack=self.max_seq_per_pack,
            inference=True,
            pad_to_global_batch_size=True,
            global_batch_size=self.gbs,
            problem_type=self.problem_type,
        ).create()

        dataloader = prepare_inference_dataloader(
            self.ipu_config, packed_data, self.micro_batch_size, self.dataloader_mode
        )
        prep_time = time.perf_counter() - prep_st

        outputs = []
        example_ids = []

        # Run the model over every packed batch, collecting logits together
        # with the example ids needed to restore input order afterwards.
        model_st = time.perf_counter()
        for batch in dataloader:
            logits = self.poplar_executor(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                token_type_ids=batch["token_type_ids"],
                position_ids=batch["position_ids"],
            )

            ids = batch["example_ids"]
            # Reshape to (ids.shape[0], max_seq_per_pack, -1): one slot per
            # sequence position inside each pack.
            outputs.append(logits.view(ids.shape[0], self.max_seq_per_pack, -1))
            example_ids.append(ids)
        model_time = time.perf_counter() - model_st
        tput = len(sentence_1) / model_time

        # Postprocess predictions to preserve input order.
        post_st = time.perf_counter()
        final_preds = self.postprocess_preds(outputs, example_ids)

        # Assumes final_preds is 2-D (examples, outputs) -- TODO confirm against
        # postprocess_preds.
        labels = self.label_categories
        if labels is not None and len(labels) == final_preds.shape[-1]:
            final_preds = {i: dict(zip(labels, row)) for i, row in enumerate(final_preds)}
        else:
            # No usable label names for these outputs: key each value by its
            # output index. This must be a dict comprehension -- a set of dicts
            # is impossible because dicts are unhashable.
            final_preds = {i: dict(enumerate(row)) for i, row in enumerate(final_preds)}

        post_proc_time = time.perf_counter() - post_st

        return {
            "predictions": final_preds,
            "throughput": tput,
            "inference_total_time": model_time,
            "preprocessing_time": prep_time,
            "postprocessing_time": post_proc_time,
        }