docker_images/open_clip/app/pipelines/zero_shot_image_classification.py:

import json
from typing import Any, Dict, List, Optional

import open_clip
import torch
import torch.nn.functional as F
from app.pipelines import Pipeline
from open_clip.pretrained import download_pretrained_from_hf
from PIL import Image


class ZeroShotImageClassificationPipeline(Pipeline):
    def __init__(self, model_id: str):
        self.model, self.preprocess = open_clip.create_model_from_pretrained(
            f"hf-hub:{model_id}"
        )
        config_path = download_pretrained_from_hf(
            model_id,
            filename="open_clip_config.json",
        )
        with open(config_path, "r", encoding="utf-8") as f:
            # TODO grab custom prompt templates from preprocess_cfg
            self.config = json.load(f)
        self.tokenizer = open_clip.get_tokenizer(f"hf-hub:{model_id}")
        self.model.eval()
        # Models trained with a sigmoid loss (e.g. SigLIP) expose a logit_bias;
        # its presence decides between sigmoid and softmax scoring below.
        self.use_sigmoid = getattr(self.model, "logit_bias", None) is not None

    def __call__(
        self,
        inputs: Image.Image,
        candidate_labels: Optional[List[str]] = None,
    ) -> List[Dict[str, Any]]:
        """
        Args:
            inputs (:obj:`PIL.Image`):
                The raw image representation as PIL. No transformation has been
                applied to the input; make all necessary transformations here.
            candidate_labels (:obj:`List[str]`):
                A list of strings representing candidate class labels.
        Return:
            A :obj:`list` of dicts like {"label": "XXX", "score": 0.82}.
            It is preferred if the returned list is in decreasing
            `score` order.
        """
        if candidate_labels is None:
            raise ValueError("'candidate_labels' is a required field")
        if isinstance(candidate_labels, str):
            candidate_labels = candidate_labels.split(",")

        prompt_templates = (
            "a bad photo of a {}.",
            "a photo of the large {}.",
            "art of the {}.",
            "a photo of the small {}.",
            "this is an image of {}.",
        )

        image = inputs.convert("RGB")
        image_inputs = self.preprocess(image).unsqueeze(0)

        # Embed every (template, label) prompt and average per label to build
        # a text-side classifier matrix with one column per candidate label.
        classifier = open_clip.build_zero_shot_classifier(
            self.model,
            tokenizer=self.tokenizer,
            classnames=candidate_labels,
            templates=prompt_templates,
            num_classes_per_batch=10,
        )

        with torch.no_grad():
            image_features = self.model.encode_image(image_inputs)
            image_features = F.normalize(image_features, dim=-1)
            logits = image_features @ classifier * self.model.logit_scale.exp()
            if self.use_sigmoid:
                # SigLIP-style models score each label independently.
                logits += self.model.logit_bias
                scores = torch.sigmoid(logits.squeeze(0))
            else:
                # CLIP-style models normalize scores across the labels.
                scores = logits.squeeze(0).softmax(0)

        output = [
            {"label": label, "score": score.item()}
            for label, score in zip(candidate_labels, scores)
        ]
        return output
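
For context, a minimal usage sketch follows. The model id and image path are illustrative assumptions, not part of this repository; any OpenCLIP checkpoint on the Hugging Face Hub that ships an open_clip_config.json should work, given network access to download it.

    # Minimal usage sketch, assuming an illustrative model id and a local image.
    from PIL import Image

    pipeline = ZeroShotImageClassificationPipeline(
        "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"  # assumed model id
    )
    image = Image.open("cat.png")  # assumed local image
    predictions = pipeline(image, candidate_labels=["cat", "dog", "bird"])
    print(predictions)
    # e.g. [{"label": "cat", "score": 0.93}, {"label": "dog", "score": 0.05}, ...]

Note that candidate_labels may also be passed as a single comma-separated string ("cat,dog,bird"), since __call__ splits strings on commas before building the classifier.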