def predict_single()

in assets/training/model_evaluation/src_distributed/model_prediction.py [0:0]


    def predict_single(self, data):
        """Predict a single batch.

        Args:
            data (tuple): ``(X_test, y_test)`` pair holding the batch features and ground-truth labels.

        Raises:
            Exception: Wrapped ``PredictException`` (``ModelPredictionInternalError``) if inference on the batch fails.

        Returns:
            tuple: ``(pred_df, y_test, perf_data, pred_proba_df)`` frames holding predictions,
                ground truth, per-row performance data and prediction scores.
        """
        X_test, y_test = data
        try:
            input_texts = X_test.values.tolist()
            if isinstance(input_texts[0], list):
                if self.task_type == SupportedTask.CHAT_COMPLETION:
                    input_data = []
                    # Render each conversation through the model's chat template before inference.
                    add_generation_prompt = self.extra_params.pop("add_generation_prompt", True)
                    for itext in input_texts:
                        input_data.append(self.tokenizer.apply_chat_template(
                            itext[0], tokenize=False, add_generation_prompt=add_generation_prompt))
                    input_texts = input_data[:]
                    self.extra_params.update({"return_full_text": False})
                    payload = MIRPayload(input_texts, self.extra_params, TaskType.CONVERSATIONAL, False)
                else:
                    # Single-column rows collapse to a string; multi-column rows (e.g. QnA
                    # question/context pairs) are kept as lists of stripped strings.
                    input_texts = [i[0] if len(i) == 1 else [j.strip() for j in i] for i in input_texts]
                    if self.task_type == SupportedTask.TEXT_GENERATION:
                        if "return_full_text" not in self.extra_params:
                            self.extra_params["return_full_text"] = False
                    if self.task_type == SupportedTask.QnA:
                        self.extra_params.update({"truncation": "longest_first"})
                    data = {
                        "input_data": {
                            "input_string": input_texts,
                            "parameters": self.extra_params,
                        }
                    }
                    payload = MIRPayload.from_dict(data)
                    payload.update_params(get_generator_params(payload.params))
                    try:
                        inference_results = self.engine.run(payload)
                    except Exception:
                        # Retry with alternative truncation strategies if the first run fails.
                        try:
                            logger.info("Failed with truncation=longest_first, retrying with only_second")
                            payload.params["truncation"] = "only_second"
                            inference_results = self.engine.run(payload)
                        except Exception:
                            logger.info("Failed with truncation=only_second, retrying with only_first")
                            payload.params["truncation"] = "only_first"
                            inference_results = self.engine.run(payload)

            logger.info(
                f"Processing new request with parameters: {payload.params}"
            )

            inference_results = None
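            # Chat-completion payloads are converted to a list of conversations before running
            # the engine; other tasks run the payload as-is.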
            if self.task_type == SupportedTask.CHAT_COMPLETION:
                payload.convert_query_to_list()
                start_ms = time.time() * 1000
                inference_results = self.engine.run(payload)
                end_ms = time.time() * 1000
                outputs = [res.response for res in inference_results]
                pred_probas = [res.scores for res in inference_results]
            else:
                start_ms = time.time() * 1000
                inference_results = self.engine.run(payload)
                end_ms = time.time() * 1000
                if self.task_type == SupportedTask.TEXT_GENERATION:
                    # Strip the echoed prompt from the generated text when the engine returns it.
                    outputs = []
                    for gt, res in zip(input_texts, inference_results):
                        if gt in res.response:
                            outputs.append(res.response[len(gt):])
                        else:
                            outputs.append(res.response)
                else:
                    outputs = [res.response for res in inference_results]
                pred_probas = [res.scores for res in inference_results]

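            # Per-row performance telemetry: batch latency plus token and character counts
            # for each input/output pair.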
            perf_data = [{
                PerformanceColumns.BATCH_SIZE_COLUMN_NAME: len(input_texts),
                PerformanceColumns.START_TIME_COLUMN_NAME: datetime.fromtimestamp(start_ms / 1000, timezone.utc).isoformat(),
                PerformanceColumns.END_TIME_COLUMN_NAME: datetime.fromtimestamp(end_ms / 1000, timezone.utc).isoformat(),
                PerformanceColumns.LATENCY_COLUMN_NAME: end_ms - start_ms,
                PerformanceColumns.OUTPUT_TOKENS_COLUMN_NAME: len(self.tokenizer(pred)["input_ids"]) if self.tokenizer is not None else 0,
                PerformanceColumns.OUTPUT_CHARACTERS_COLUMN_NAME: len(pred) if isinstance(pred, str) else 1,
                PerformanceColumns.INPUT_CHARACTERS_COLUMN_NAME: len(gt) if isinstance(gt, str) else 1,
                PerformanceColumns.INPUT_TOKENS_COLUMN_NAME: len(self.tokenizer(gt)["input_ids"]) if self.tokenizer is not None else 0
            } for gt, pred in zip(input_texts, outputs)]
            pred_proba_df = pd.DataFrame(pred_probas, index=X_test.index)
            perf_data = pd.DataFrame(perf_data)

            if self.task_type in (SupportedTask.CHAT_COMPLETION, TaskType.CONVERSATIONAL):
                pred_df = self._make_chat_completion_data(X_test.copy(deep=True), outputs,
                                                          col_name=ChatCompletionConstants.OUTPUT_FULL_CONVERSATION)
                pred_df[ChatCompletionConstants.OUTPUT] = outputs
                y_test = pd.DataFrame(y_test, columns=["ground_truth"], index=X_test.index)
                # y_test = self._make_chat_completion_data(X_test.copy(deep=True), y_test, col_name="ground_truth")
                return pred_df, y_test, perf_data, pred_proba_df

            pred_df = pd.DataFrame(outputs, index=X_test.index, columns=["prediction"])
            if isinstance(y_test, pd.Series):
                y_test = y_test.to_frame()
            elif isinstance(y_test, (np.ndarray, list)):
                y_test = pd.DataFrame(y_test, index=X_test.index)
            return pred_df, y_test, perf_data, pred_proba_df

        except Exception as e:
            exception = get_azureml_exception(PredictException, ModelPredictionInternalError, e,
                                              wrap_azureml_ex=False, error=repr(e))
            log_traceback(exception, logger)
            raise exception
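
A minimal calling sketch, assuming an already-initialized instance of the surrounding class (here named `predictor`, a hypothetical stand-in) with its engine, tokenizer, task_type and extra_params configured; the column names and sample texts are illustrative only:

import pandas as pd

# Hypothetical batch: one text column of prompts and matching ground-truth labels.
X_test = pd.DataFrame({"text": ["Translate to French: hello", "Translate to French: goodbye"]})
y_test = pd.Series(["bonjour", "au revoir"])

# `predictor` stands in for an initialized object exposing predict_single()
# (e.g. built for SupportedTask.TEXT_GENERATION).
pred_df, gt_df, perf_data, pred_proba_df = predictor.predict_single((X_test, y_test))

print(pred_df["prediction"].tolist())   # generated continuations, one per input row
print(perf_data.columns.tolist())       # latency, token and character count columns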