community-content/vertex_model_garden/model_oss/imagebind/handler.py
def postprocess(self, output_list: List[Dict[str, Any]]) -> List[Any]:
"""Postprocesses model outputs for the task of interest.
For feature embedding generation, returns the embeddings for each modality
for each input.
For zero-shot classification, generates classification probabilities
between the inputs of a pair of modalities for all possible pairings.
Args:
output_list: A list of model outputs, with each output being a dictionary
of modality (key): embedding (value) pairs.
Returns:
A list of postprocessed model outputs for the task of interest, with each
output corresponding to an input.
Raises:
ValueError: Fewer than two modalities are provided for zero-shot
classification, or the task is not supported.
"""
  preds = []
  if self.task == constants.FEATURE_EMBEDDING_GENERATION:
    # Convert each modality's embedding tensor to a nested list so the
    # prediction is JSON-serializable.
    for item in output_list:
      preds.append({k: v.tolist() for k, v in item.items()})
  elif self.task == constants.ZERO_SHOT_CLASSIFICATION:
    for item in output_list:
      modalities = list(item.keys())
      if len(modalities) < 2:
        raise ValueError(
            "Two or more modalities are needed for task"
            f" {constants.ZERO_SHOT_CLASSIFICATION}."
        )
      pairwise_probs = {}
      # For every ordered pair of distinct modalities, score each row of m1
      # against every row of m2 via a dot product and normalize the scores
      # with a softmax over the m2 inputs.
      for m1 in modalities:
        for m2 in modalities:
          if m1 == m2:
            continue
          probs = torch.softmax(item[m1] @ item[m2].T, dim=-1)
          pairwise_probs[
              f"Classify each input in {m1} (row) against inputs in"
              f" {m2} (column)"
          ] = probs.tolist()
      preds.append(pairwise_probs)
  else:
    raise ValueError(f"Task {self.task} is not supported by the handler.")
  return preds
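
# A minimal usage sketch (not part of the original handler). It assumes a
# handler instance named `handler` whose `task` is
# constants.ZERO_SHOT_CLASSIFICATION, and an `output_list` holding one batch
# of embeddings, e.g. three text inputs and two vision inputs sharing a
# 1024-dimensional embedding space:
#
#   output_list = [{
#       "text": torch.randn(3, 1024),
#       "vision": torch.randn(2, 1024),
#   }]
#   preds = handler.postprocess(output_list)
#
# preds[0] then maps keys such as "Classify each input in text (row) against
# inputs in vision (column)" to a 3x2 list of probabilities, where each row
# sums to 1 because the softmax is taken over the columns (dim=-1).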