def postprocess()

in community-content/vertex_model_garden/model_oss/imagebind/handler.py [0:0]


  def postprocess(self, output_list: List[Dict[str, Any]]) -> List[Any]:
    """Turns raw model outputs into the response for the configured task.

    Behavior depends on `self.task`:

    * Feature embedding generation: every embedding in every output is
      converted to a plain nested list via `.tolist()`, keyed by modality.
    * Zero-shot classification: for every ordered pair of distinct modalities
      in an output, the embeddings of the first modality are multiplied with
      the transposed embeddings of the second, and a softmax over the last
      dimension turns each row into probabilities. Both orderings of a pair
      are reported under separate, human-readable keys.

    Args:
      output_list: A list of model outputs, with each output being a dictionary
        of modality (key): embedding (value) pairs.

    Returns:
      A list with one entry per element of `output_list`, in the same order.
      Each entry is a dictionary: modality -> embedding (as lists) for feature
      embedding generation, or description -> probability matrix (as lists)
      for zero-shot classification.

    Raises:
      ValueError: An output holds fewer than two modalities for zero-shot
        classification, or the task is not supported.
    """
    if self.task == constants.FEATURE_EMBEDDING_GENERATION:
      return [
          {
              modality: embedding.tolist()
              for modality, embedding in output.items()
          }
          for output in output_list
      ]

    if self.task != constants.ZERO_SHOT_CLASSIFICATION:
      raise ValueError(f"Task {self.task} is not supported by the handler.")

    results = []
    for output in output_list:
      names = list(output)
      if len(names) < 2:
        raise ValueError(
            "Two or more modalities are needed for task"
            f" {constants.ZERO_SHOT_CLASSIFICATION}."
        )

      # Ordered pairs: (a, b) and (b, a) are both kept because the softmax is
      # taken row-wise, so the two directions give different probabilities.
      ordered_pairs = [
          (query, candidate)
          for query in names
          for candidate in names
          if query != candidate
      ]

      scores = {}
      for query, candidate in ordered_pairs:
        similarity = output[query] @ output[candidate].T
        description = (
            f"Classify each input in {query} (row) against inputs in"
            f" {candidate} (column)"
        )
        scores[description] = torch.softmax(similarity, dim=-1).tolist()
      results.append(scores)
    return results