def predict()

in assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/llava/llava_mlflow_wrapper.py


    def predict(self, context: mlflow.pyfunc.PythonModelContext, input_data: pd.DataFrame) -> pd.DataFrame:
        """Perform inference on the input data.

        :param context: MLflow context containing artifacts that the model can use for inference
        :type context: mlflow.pyfunc.PythonModelContext
        :param input_data: Pandas DataFrame with columns ["image"], ["prompt"], and ["direct_question"], where
                           the image is either a URL or a base64-encoded string, the prompt is the dialog so far
                           between the user and the model, and the direct question is a single question from the
                           user.
        :type input_data: pd.DataFrame
        :return: Pandas DataFrame with column ["response"] containing the model's response to the dialog so far.
        """
        from llava.constants import IMAGE_TOKEN_INDEX
        from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
        from vision_utils import process_image

        # Do inference one input at a time.
        responses = []
        for image, prompt, direct_question in zip(
            input_data[MLflowSchemaLiterals.INPUT_COLUMN_IMAGE],
            input_data[MLflowSchemaLiterals.INPUT_COLUMN_PROMPT],
            input_data[MLflowSchemaLiterals.INPUT_COLUMN_DIRECT_QUESTION],
        ):
            # Resolve the image (URL or base64 string) to raw bytes and load it as a PIL Image object.
            pil_image = Image.open(io.BytesIO(process_image(image)))

            # If no prompt is specified, build one from the direct question column using the conversation
            # template for this model variant.
            if not prompt:
                prompt_from_direct_question = True
                if self._model_version == self.LLAVA_MPT:
                    prompt = (
                        f"A conversation between a user and an LLM-based AI assistant. The assistant gives helpful "
                        f"and honest answers.<|im_end|><|im_start|>user\n<im_start><image><im_end>\n"
                        f"{direct_question}<|im_end|><|im_start|>assistant"
                    )
                elif self._model_version == self.LLAVA_7B:
                    prompt = (
                        f"[INST] <<SYS>>\nYou are a helpful language and vision assistant. You are able to understand "
                        f"the visual content that the user provides, and assist the user with a variety of tasks "
                        f"using natural language.\n<</SYS>>\n\n<im_start><image><im_end>\n{direct_question} [/INST]"
                    )
                elif self._model_version == self.LLAVA_7B_15:
                    prompt = (
                        f"A chat between a curious human and an artificial intelligence assistant. "
                        f"The assistant gives helpful, detailed, and polite answers to the human's questions. USER: "
                        f"<image>\n{direct_question} ASSISTANT:"
                    )
                elif self._model_version == self.LLAVA_13B:
                    prompt = (
                        f"A chat between a curious human and an artificial intelligence assistant. "
                        f"The assistant gives helpful, detailed, and polite answers to the human's questions. USER: "
                        f"<im_start><image><im_end>\n{direct_question} ASSISTANT:"
                    )
                elif self._model_version == self.LLAVA_13B2:
                    prompt = (
                        f"[INST] <<SYS>>\nYou are a helpful language and vision assistant. You are able to understand "
                        f"the visual content that the user provides, and assist the user with a variety of tasks "
                        f"using natural language.\n<</SYS>>\n\n<image>\n{direct_question} [/INST]"
                    )
                elif self._model_version == self.LLAVA_13B_15:
                    prompt = (
                        f"A chat between a curious human and an artificial intelligence assistant. "
                        f"The assistant gives helpful, detailed, and polite answers to the human's questions. USER: "
                        f"<image>\n{direct_question} ASSISTANT:"
                    )
            else:
                prompt_from_direct_question = False

            # Preprocess the image into pixel values, cast to half precision, and move to GPU.
            image_tensor = self._image_processor.preprocess(
                pil_image, return_tensors="pt"
            )["pixel_values"].half().cuda()

            # Tokenize the prompt, mapping the image placeholder to the special image token index.
            input_ids = tokenizer_image_token(
                prompt, self._tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
            ).unsqueeze(0).cuda()
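            # Stop generation once the stop string for this conversation style appears in the output.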
            stopping_criteria = KeywordsStoppingCriteria([self._stop_str], self._tokenizer, input_ids)

            # For small models on V100 machines, long prompts cause GPU out-of-memory errors from which the server
            # does not recover. To prevent this, we use a length threshold that allows a small number of
            # question-answer pairs (e.g. 5-10) in each prompt.
            if self._model_version in [self.LLAVA_MPT, self.LLAVA_7B, self.LLAVA_7B_15]:
                prompt_length = max([len(i) for i in input_ids])
                if prompt_length > MAX_PROMPT_LENGTH:
                    raise ValueError(
                        f"Prompt too long: {prompt_length} tokens. Maximum allowed is {MAX_PROMPT_LENGTH}."
                    )

            # Generate the response (sampling at low temperature, up to 1024 new tokens).
            output_ids = self._model.generate(
                input_ids,
                images=image_tensor,
                do_sample=True,
                temperature=0.2,
                max_new_tokens=1024,
                streamer=self._streamer,
                use_cache=True,
                stopping_criteria=[stopping_criteria]
            )

            # Decode the newly generated tokens (excluding the prompt) to text and trim.
            response = self._tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
            if prompt_from_direct_question:
                if response.startswith(": "):
                    response = response[2:]
                if response.endswith(self._stop_str):
                    response = response[:-len(self._stop_str)]
                if response.endswith("</s>"):
                    response = response[:-len("</s>")]

            # Accumulate into response list.
            responses.append(response)

        # Collect the responses into a Pandas DataFrame.
        df_responses = pd.DataFrame({MLflowSchemaLiterals.OUTPUT_COLUMN_RESPONSE: responses})
        return df_responses
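
A minimal usage sketch, assuming the wrapper has been logged as an MLflow pyfunc model and that the schema
columns resolve to "image", "prompt", "direct_question", and "response" as described in the docstring; the
model URI and the local image file below are hypothetical:

    import base64

    import mlflow.pyfunc
    import pandas as pd

    # Hypothetical model URI; substitute the registered LLaVA MLflow model.
    model = mlflow.pyfunc.load_model("models:/llava-13b/1")

    # "image" accepts a URL or a base64-encoded image. Leaving "prompt" empty makes the
    # wrapper build one from "direct_question" using the model's conversation template.
    with open("sample.jpg", "rb") as f:  # hypothetical local image
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    input_df = pd.DataFrame(
        {
            "image": [image_b64],
            "prompt": [""],
            "direct_question": ["What is shown in this image?"],
        }
    )

    predictions = model.predict(input_df)
    print(predictions["response"].iloc[0])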