docker_images/open_clip/app/pipelines/zero_shot_image_classification.py
import json
from typing import Any, Dict, List, Optional

import open_clip
import torch
import torch.nn.functional as F
from app.pipelines import Pipeline
from open_clip.pretrained import download_pretrained_from_hf
from PIL import Image


class ZeroShotImageClassificationPipeline(Pipeline):
def __init__(self, model_id: str):
self.model, self.preprocess = open_clip.create_model_from_pretrained(
f"hf-hub:{model_id}"
)
config_path = download_pretrained_from_hf(
model_id,
filename="open_clip_config.json",
)
with open(config_path, "r", encoding="utf-8") as f:
# TODO grab custom prompt templates from preprocess_cfg
self.config = json.load(f)
self.tokenizer = open_clip.get_tokenizer(f"hf-hub:{model_id}")
self.model.eval()
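        # SigLIP-style checkpoints expose a learned logit_bias; when it is
        # present, scores are computed with a sigmoid instead of a softmax.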
self.use_sigmoid = getattr(self.model, "logit_bias", None) is not None

    def __call__(
self,
inputs: Image.Image,
candidate_labels: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
"""
Args:
inputs (:obj:`PIL.Image`):
The raw image representation as PIL.
No transformation made whatsoever from the input. Make all necessary transformations here.
candidate_labels (List[str]):
A list of strings representing candidate class labels.
Return:
A :obj:`list`:. The list contains items that are dicts should be liked {"label": "XXX", "score": 0.82}
It is preferred if the returned list is in decreasing `score` order
"""
if candidate_labels is None:
raise ValueError("'candidate_labels' is a required field")
if isinstance(candidate_labels, str):
candidate_labels = candidate_labels.split(",")
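        # Generic CLIP-style prompt templates applied to every candidate
        # label (see the TODO above about model-specific templates).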
prompt_templates = (
"a bad photo of a {}.",
"a photo of the large {}.",
"art of the {}.",
"a photo of the small {}.",
"this is an image of {}.",
)
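        # Ensure a 3-channel image, apply the model's preprocessing
        # transform, and add a batch dimension.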
image = inputs.convert("RGB")
image_inputs = self.preprocess(image).unsqueeze(0)
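        # Build the text-side classifier: one embedding column per label,
        # averaged over the prompt templates.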
classifier = open_clip.build_zero_shot_classifier(
self.model,
tokenizer=self.tokenizer,
classnames=candidate_labels,
templates=prompt_templates,
num_classes_per_batch=10,
)
with torch.no_grad():
image_features = self.model.encode_image(image_inputs)
image_features = F.normalize(image_features, dim=-1)
            # Scaled cosine similarity between the image embedding and each
            # class embedding (both are L2-normalized).
            logits = (image_features @ classifier) * self.model.logit_scale.exp()
            if self.use_sigmoid:
                # SigLIP scoring: add the learned bias, then apply an
                # elementwise sigmoid (independent per-label probabilities).
                logits += self.model.logit_bias
                scores = torch.sigmoid(logits.squeeze(0))
            else:
                # CLIP scoring: softmax normalizes scores across the labels.
                scores = logits.squeeze(0).softmax(0)
        output = [
            {
                "label": label,
                "score": score.item(),
            }
            for label, score in zip(candidate_labels, scores)
        ]
        return sorted(output, key=lambda item: item["score"], reverse=True)
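

if __name__ == "__main__":
    # Minimal local smoke test; a sketch only, not part of the served API.
    # The model id and image path are illustrative placeholders and assume
    # network access to the Hugging Face Hub.
    pipe = ZeroShotImageClassificationPipeline(
        "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
    )
    preds = pipe(
        Image.open("example.jpg"),
        candidate_labels=["cat", "dog", "bird"],
    )
    print(preds)  # e.g. [{"label": "cat", "score": 0.93}, ...]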