in community-content/vertex_model_garden/model_oss/llava/handler.py [0:0]
def preprocess(self, data: List[Dict[str, Any]]) -> Any:
"""Runs the preprocessing to tokenize image and the prompt."""
if len(data) > 1:
raise ValueError(
"LLava original repo currently does not support batch inference."
" https://github.com/haotian-liu/LLaVA/issues/754"
)
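# The request batch is guaranteed to hold a single instance at this point.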
data = data[0]
prompt, base64_image = data["prompt"], data["base64_image"]
# Adds the proper image token to the prompt.
image_token_se = (
llava_constants.DEFAULT_IM_START_TOKEN
+ llava_constants.DEFAULT_IMAGE_TOKEN
+ llava_constants.DEFAULT_IM_END_TOKEN
)
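# image_token_se wraps the image token in start/end markers for checkpoints
# trained with mm_use_im_start_end enabled.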
if llava_constants.IMAGE_PLACEHOLDER in prompt:
if self.model.config.mm_use_im_start_end:
prompt = re.sub(
llava_constants.IMAGE_PLACEHOLDER, image_token_se, prompt
)
else:
prompt = re.sub(
llava_constants.IMAGE_PLACEHOLDER,
llava_constants.DEFAULT_IMAGE_TOKEN,
prompt,
)
else:
if self.model.config.mm_use_im_start_end:
prompt = image_token_se + "\n" + prompt
else:
prompt = llava_constants.DEFAULT_IMAGE_TOKEN + "\n" + prompt
# Formats the prompt as a conversation to be fed to the model.
conv = conversation.conv_llava_v1.copy()
conv.append_message(role=conv.roles[0], message=prompt)
conv.append_message(role=conv.roles[1], message=None)
prompt = conv.get_prompt()
# Tokenizes the prompt, including the special image token.
input_ids = (
mm_utils.tokenizer_image_token(
prompt=prompt,
tokenizer=self.tokenizer,
image_token_index=llava_constants.IMAGE_TOKEN_INDEX,
return_tensors="pt",
)
.unsqueeze(0)
.to(self.device)
)
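# Decodes the base64 string into a PIL image and forces 3-channel RGB.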
images = [
image_format_converter.base64_to_image(image_str=base64_image).convert(
"RGB"
)
]
# Preprocesses the image into a pixel tensor for the vision encoder.
images_tensor = mm_utils.process_images(
images=images,
image_processor=self.image_processor,
model_cfg=self.model.config,
).to(self.device, dtype=torch.float16)
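# Saves the stop string and keywords on the handler so later steps
# (e.g., the generation stopping criteria) can use them.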
self.stop_str = conversation.conv_llava_v1.sep2
self.keywords = [self.stop_str]
return input_ids, images_tensor
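# Illustrative only (not part of the handler): a request payload matching this
# preprocess contract might look like the following, where the image file path
# and the handler instance are placeholders.
#
#   import base64
#   with open("example.jpg", "rb") as f:
#       encoded_image = base64.b64encode(f.read()).decode("utf-8")
#   data = [{"prompt": "Describe this image.", "base64_image": encoded_image}]
#   input_ids, images_tensor = handler.preprocess(data)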