in community-content/vertex_model_garden/model_oss/transformers/handler.py [0:0]
def inference(self, data: Any, *args, **kwargs) -> List[Any]:
  """Runs inference on a (texts, images) pair for the configured task."""
  texts, images = data
  preds = None
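  # Zero-shot image classification: the whole batch of images is scored
  # against the candidate label texts in a single pipeline call.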
  if self.task == ZERO_CLASSIFICATION:
    preds = self.pipeline(images=images, candidate_labels=texts)
  elif self.task == ZERO_DETECTION:
    # The object detection pipeline doesn't support batch prediction.
    preds = self.pipeline(image=images[0], candidate_labels=texts[0])
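  # Image captioning: Salesforce BLIP checkpoints go through the raw
  # processor + model.generate() path; other models use the pipeline.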
  elif self.task == IMAGE_CAPTIONING:
    if self.salesforce_blip:
      inputs = self.processor(images[0], return_tensors="pt").to(
          self.map_location, self.torch_type
      )
      preds = self.model.generate(**inputs)
      preds = [
          self.processor.decode(preds[0], skip_special_tokens=True).strip()
      ]
    else:
      preds = self.pipeline(images=images)
  elif self.task == VQA:
    # The VQA pipeline doesn't support batch prediction.
    if self.salesforce_blip:
      inputs = self.processor(images[0], texts[0], return_tensors="pt").to(
          self.map_location, self.torch_type
      )
      preds = self.model.generate(**inputs)
      preds = [
          self.processor.decode(preds[0], skip_special_tokens=True).strip()
      ]
    else:
      preds = self.pipeline(image=images[0], question=texts[0])
  elif self.task == DQA:
    # The DQA pipeline doesn't support batch prediction.
    preds = self.pipeline(image=images[0], question=texts[0])
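  # Feature embedding: CLIP-style encoders expose get_text_features /
  # get_image_features; the embeddings are returned as plain Python lists
  # wrapped in a single-element batch.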
  elif self.task == FEATURE_EMBEDDING:
    preds = {}
    if texts:
      inputs = self.tokenizer(
          text=texts, padding=True, return_tensors="pt"
      ).to(self.map_location)
      text_features = self.model.get_text_features(**inputs)
      preds["text_features"] = text_features.detach().cpu().numpy().tolist()
    if images:
      inputs = self.processor(images=images, return_tensors="pt").to(
          self.map_location
      )
      image_features = self.model.get_image_features(**inputs)
      preds["image_features"] = image_features.detach().cpu().numpy().tolist()
    preds = [preds]
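  # Summarization: FLAN-T5 inputs are wrapped in SUMMARIZATION_TEMPLATE (the
  # model is instruction-tuned); bart-large-cnn is run on a single text with
  # fixed length bounds.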
  elif self.task == SUMMARIZATION and FLAN_T5 in self.model_id:
    texts = [SUMMARIZATION_TEMPLATE.format(input=text) for text in texts]
    preds = self.pipeline(texts, max_length=130)
  elif self.task == SUMMARIZATION and self.model_id == BART_LARGE_CNN:
    preds = self.pipeline(
        texts[0], max_length=130, min_length=30, do_sample=False
    )
  else:
    raise ValueError(f"Invalid TASK: {self.task}")
  return preds
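
A minimal standalone sketch of what the ZERO_CLASSIFICATION branch above does,
using the Hugging Face zero-shot-image-classification pipeline directly. The
checkpoint name and the image path are illustrative placeholders, not values
taken from this handler.

from PIL import Image
from transformers import pipeline

# Example checkpoint only; the real handler serves its deployed model.
classifier = pipeline(
    task="zero-shot-image-classification",
    model="openai/clip-vit-base-patch32",
)
image = Image.open("example.jpg")  # placeholder input image
# Mirrors: preds = self.pipeline(images=images, candidate_labels=texts)
preds = classifier(image, candidate_labels=["a cat", "a dog", "a car"])
print(preds)  # list of {"score": ..., "label": ...} dicts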