def postprocess()

in community-content/vertex_model_garden/model_oss/open_clip/handler.py [0:0]


  def postprocess(self, features: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Postprocesses the image/text features for downstream tasks.

    Args:
      features: One dict per input item mapping feature keys to tensor-like
        values (presumably `torch.Tensor` — they must support `.tolist()`,
        `.norm()`, `@`, `.T` and `.softmax()`). The dicts and their tensors
        are not modified.

    Returns:
      - If `_BIOMED_CLIP_MODEL` is a substring of `self.model_name`:
        `features`, returned as-is.
      - For the `_FEATURE_EMBEDDING` task: one dict per item with every value
        converted to a (nested) Python list.
      - For the `_ZERO_CLASSIFICATION` task: one nested list per item holding
        the softmax over the last dimension of the scaled image/text
        similarity matrix (note: lists, not dicts, despite the annotation).
      - For any other task: an empty list.

    Raises:
      ValueError: For the zero-shot classification task, if an item is
        missing the image features or the text features.
    """
    if _BIOMED_CLIP_MODEL in self.model_name:
      return features
    preds = []
    if self.task == _FEATURE_EMBEDDING:
      preds = [{k: v.tolist() for k, v in item.items()} for item in features]
    elif self.task == _ZERO_CLASSIFICATION:
      for item in features:
        image_features = item.get(_IMAGE_FEATURES_KEY, None)
        text_features = item.get(_TEXT_FEATURES_KEY, None)
        if image_features is None or text_features is None:
          raise ValueError(
              "Missing input for {} task. {} received.".format(
                  _ZERO_CLASSIFICATION, item.keys()
              )
          )
        # Normalize out of place: the previous in-place `/=` overwrote the
        # tensors held in the caller's `features` dicts, and in-place ops can
        # fail on tensors created under `torch.inference_mode()`.
        image_features = image_features / image_features.norm(
            dim=-1, keepdim=True
        )
        text_features = text_features / text_features.norm(
            dim=-1, keepdim=True
        )
        # Similarities scaled by 100.0, then softmaxed across the text rows.
        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        preds.append(text_probs.tolist())
    # NOTE(review): an unrecognized `self.task` falls through and yields [];
    # confirm the task is validated elsewhere (e.g. at handler init).
    return preds