def inference()

in community-content/vertex_model_garden/model_oss/transformers/handler.py [0:0]
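
This method dispatches a preprocessed (texts, images) request to the
appropriate Hugging Face pipeline, or to a direct model call, based on
self.task.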


  def inference(self, data: Any, *args, **kwargs) -> List[Any]:
    """Run the inference."""
    texts, images = data
    preds = None
    if self.task == ZERO_CLASSIFICATION:
      preds = self.pipeline(images=images, candidate_labels=texts)
    elif self.task == ZERO_DETECTION:
      # The object detection pipeline doesn't support batch prediction.
      preds = self.pipeline(image=images[0], candidate_labels=texts[0])
    elif self.task == IMAGE_CAPTIONING:
      if self.salesforce_blip:
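        # Salesforce BLIP checkpoints bypass the pipeline and call the
        # processor and model.generate() directly.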
        inputs = self.processor(images[0], return_tensors="pt").to(
            self.map_location, self.torch_type
        )
        preds = self.model.generate(**inputs)
        preds = [
            self.processor.decode(preds[0], skip_special_tokens=True).strip()
        ]
      else:
        preds = self.pipeline(images=images)
    elif self.task == VQA:
      # The VQA pipeline doesn't support batch prediction.
      if self.salesforce_blip:
        inputs = self.processor(images[0], texts[0], return_tensors="pt").to(
            self.map_location, self.torch_type
        )
        preds = self.model.generate(**inputs)
        preds = [
            self.processor.decode(preds[0], skip_special_tokens=True).strip()
        ]
      else:
        preds = self.pipeline(image=images[0], question=texts[0])
    elif self.task == DQA:
      # The DQA pipeline doesn't support batch prediction.
      preds = self.pipeline(image=images[0], question=texts[0])
    elif self.task == FEATURE_EMBEDDING:
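      # CLIP-style models expose separate text and image encoders, so text
      # and image features are computed independently.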
      preds = {}
      if texts:
        inputs = self.tokenizer(
            text=texts, padding=True, return_tensors="pt"
        ).to(self.map_location)
        text_features = self.model.get_text_features(**inputs)
        preds["text_features"] = text_features.detach().cpu().numpy().tolist()
      if images:
        inputs = self.processor(images=images, return_tensors="pt").to(
            self.map_location
        )
        image_features = self.model.get_image_features(**inputs)
        preds["image_features"] = image_features.detach().cpu().numpy().tolist()
      preds = [preds]
    elif self.task == SUMMARIZATION and FLAN_T5 in self.model_id:
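      # FLAN-T5 expects an instruction-style prompt, so each input is wrapped
      # in the summarization template first.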
      texts = [SUMMARIZATION_TEMPLATE.format(input=text) for text in texts]
      preds = self.pipeline(texts, max_length=130)
    elif self.task == SUMMARIZATION and self.model_id == BART_LARGE_CNN:
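      # bart-large-cnn is run on a single input with fixed length bounds and
      # sampling disabled (do_sample=False).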
      preds = self.pipeline(
          texts[0], max_length=130, min_length=30, do_sample=False
      )
    else:
      raise ValueError(f"Invalid TASK: {self.task}")
    return preds
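

For reference, a minimal sketch of how a caller might invoke this method for
the VQA task. The handler class name, its constructor arguments, and the
checkpoint below are hypothetical placeholders for illustration; only the
(texts, images) input contract comes from the code above.


  from PIL import Image

  # Hypothetical: a handler instance already initialized for the VQA task
  # with a Salesforce BLIP checkpoint.
  handler = Handler(task=VQA, model_id="Salesforce/blip-vqa-base")

  image = Image.open("example.jpg")
  texts = ["How many cats are in the picture?"]
  images = [image]

  # inference() expects a (texts, images) tuple. For VQA only the first
  # element of each list is used, since the pipeline is not batched.
  preds = handler.inference((texts, images))
  print(preds)  # A single decoded answer string wrapped in a list.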