def preprocess()

in community-content/vertex_model_garden/model_oss/llava/handler.py


  def preprocess(self, data: List[Dict[str, Any]]) -> Any:
    """Runs the preprocessing to tokenize image and the prompt."""
    if len(data) > 1:
      raise ValueError(
          "LLava original repo currently does not support batch inference."
          " https://github.com/haotian-liu/LLaVA/issues/754"
      )
    data = data[0]
    prompt, base64_image = data["prompt"], data["base64_image"]
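    # Each request carries a text prompt and a base64-encoded image, e.g.
    # [{"prompt": "What is in this image?", "base64_image": "<base64 string>"}].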

    # Adds the proper image token to the prompt.
    image_token_se = (
        llava_constants.DEFAULT_IM_START_TOKEN
        + llava_constants.DEFAULT_IMAGE_TOKEN
        + llava_constants.DEFAULT_IM_END_TOKEN
    )
    if llava_constants.IMAGE_PLACEHOLDER in prompt:
      if self.model.config.mm_use_im_start_end:
        prompt = re.sub(
            llava_constants.IMAGE_PLACEHOLDER, image_token_se, prompt
        )
      else:
        prompt = re.sub(
            llava_constants.IMAGE_PLACEHOLDER,
            llava_constants.DEFAULT_IMAGE_TOKEN,
            prompt,
        )
    else:
      if self.model.config.mm_use_im_start_end:
        prompt = image_token_se + "\n" + prompt
      else:
        prompt = llava_constants.DEFAULT_IMAGE_TOKEN + "\n" + prompt
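    # The prompt now contains exactly one image token; with upstream LLaVA
    # defaults this is "<image>", or "<im_start><image><im_end>" when
    # mm_use_im_start_end is set.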

    # Formats the prompt as a conversation to be fed to the model.
    conv = conversation.conv_llava_v1.copy()
    conv.append_message(role=conv.roles[0], message=prompt)
    conv.append_message(role=conv.roles[1], message=None)
    prompt = conv.get_prompt()
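    # With conv_llava_v1 the prompt becomes roughly
    # "<system prompt> USER: <image>\n<question> ASSISTANT:"; the exact template
    # is defined by the LLaVA conversation module.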

    # Tokenizes the prompt, including the special image token.
    input_ids = (
        mm_utils.tokenizer_image_token(
            prompt=prompt,
            tokenizer=self.tokenizer,
            image_token_index=llava_constants.IMAGE_TOKEN_INDEX,
            return_tensors="pt",
        )
        .unsqueeze(0)
        .to(self.device)
    )
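    # tokenizer_image_token splits the prompt around the image token and splices
    # in IMAGE_TOKEN_INDEX (a sentinel id, -200 in upstream LLaVA) so the model
    # can later substitute the image features at that position.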

    images = [
        image_format_converter.base64_to_image(image_str=base64_image).convert(
            "RGB"
        )
    ]
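    # base64_to_image decodes the request string into a PIL image; converting to
    # RGB guards against grayscale or RGBA inputs.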
    # Preprocesses the image into a tensor for the vision encoder.
    images_tensor = mm_utils.process_images(
        images=images,
        image_processor=self.image_processor,
        model_cfg=self.model.config,
    ).to(self.device, dtype=torch.float16)

    # Records the stop string (conv_llava_v1's sep2) and keyword list; these are
    # presumably consumed later (e.g., by a stopping criterion or a postprocess
    # step) to cut generation at the end of the assistant turn.
    self.stop_str = conversation.conv_llava_v1.sep2
    self.keywords = [self.stop_str]

    return input_ids, images_tensor
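
A minimal usage sketch, assuming an already-initialized handler instance (named
"handler" here, with tokenizer, image processor, model, and device set up); the
payload keys match the ones read above, and the file name is illustrative:

  import base64

  # Illustrative: encode a local image for the request payload.
  with open("example.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

  # The handler expects a single-element list; batch inference is unsupported.
  request = [{
      "prompt": "What is shown in this image?",
      "base64_image": image_b64,
  }]

  input_ids, images_tensor = handler.preprocess(request)
  # input_ids: token ids of shape (1, seq_len) on handler.device.
  # images_tensor: float16 image tensor ready for the vision encoder.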