def compute_image_embedding()

in video_processing/reference_video_similarity.py [0:0]


def compute_image_embedding(image_path, model, preprocessor, device, dtype):
    """
    Compute an L2-normalized embedding for a single image.

    Args:
        image_path: Location of the image, passed straight to ``Image.open``.
        model: Callable invoked as ``model(**inputs)``; its result must expose
            a ``pooler_output`` tensor.
        preprocessor: Callable invoked as
            ``preprocessor(image, return_tensors="pt")``; its result must
            support ``.to(device)`` and ``**`` unpacking.
        device: Device specification accepted by ``torch.device``; the inputs
            are moved there and its type selects the autocast backend.
        dtype: Reduced-precision dtype handed to ``torch.autocast``.

    Returns:
        A 1-D NumPy array holding the flattened ``pooler_output`` after it has
        been divided by its L2 norm along the last dimension.
    """
    # Close the underlying file promptly; ``convert`` returns a separate,
    # fully loaded image that stays usable after the handle is released.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    image_input = preprocessor(image, return_tensors="pt").to(device)
    device_type = torch.device(device).type

    # Both context managers must be entered, hence the comma. Writing
    # ``no_grad() and autocast(...)`` evaluates to the autocast manager only,
    # which leaves gradient tracking on and makes ``.numpy()`` fail for any
    # model whose parameters require grad.
    with torch.no_grad(), torch.autocast(device_type, dtype=dtype):
        embedding = model(**image_input).pooler_output
        embedding = embedding / embedding.norm(dim=-1, keepdim=True)

    # NumPy has no bfloat16 dtype, so upcast in that one case to keep the
    # conversion below from raising; other dtypes are returned unchanged.
    if embedding.dtype == torch.bfloat16:
        embedding = embedding.float()

    return embedding.cpu().numpy().flatten()