optimum/habana/transformers/models/owlvit/modeling_owlvit.py (28 lines of code) (raw):

from typing import Optional, Tuple

import torch


def gaudi_owlvitclasspredictionhead_forward(
    self,
    image_embeds: torch.FloatTensor,
    query_embeds: Optional[torch.FloatTensor],
    query_mask: Optional[torch.Tensor],
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """
    Compute per-patch class logits against a set of query embeddings.

    Copied from modeling_owlvit: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/owlvit/modeling_owlvit.py#L1233
    The modifications are:
        - Replace the torch.where by torch.clamp
        - Allocate the all-zero logits directly on the target device instead of
          allocating them on the default device and copying them over
        - Drop the `query_mask` unsqueeze, whose result the clamp never reads

    Args:
        image_embeds: Image features fed to `self.dense0`, `self.logit_shift`
            and `self.logit_scale`. The first two dimensions of the projected
            features are treated as (batch_size, num_patches).
        query_embeds: Query features compared against the projected image
            features along their last dimension. When `None`, no comparison
            is made and all-zero logits are returned.
        query_mask: When not `None`, the logits are clamped from below at
            -1e6 and returned as float32. Only its presence is checked; its
            values are not consulted.

    Returns:
        A tuple `(pred_logits, image_class_embeds)`. `image_class_embeds` is
        the raw output of `self.dense0` when `query_embeds` is `None`, and the
        L2-normalized projection otherwise.
    """
    image_class_embeds = self.dense0(image_embeds)

    if query_embeds is None:
        batch_size, num_patches = image_class_embeds.shape[:2]
        # Create the tensor where it will be used; this avoids a host-side
        # allocation followed by a device copy. The dtype stays the default one.
        pred_logits = torch.zeros(
            (batch_size, num_patches, self.query_dim),
            device=image_class_embeds.device,
        )
        return (pred_logits, image_class_embeds)

    # Normalize image and text features (epsilon guards against zero norms)
    image_class_embeds = image_class_embeds / (torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6)
    query_embeds = query_embeds / (torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6)

    # Get class predictions: one similarity score per (patch, query) pair
    pred_logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds)

    # Apply a learnable shift and scale to logits; ELU + 1 keeps the scale positive
    logit_shift = self.logit_shift(image_embeds)
    logit_scale = self.elu(self.logit_scale(image_embeds)) + 1
    pred_logits = (pred_logits + logit_shift) * logit_scale

    if query_mask is not None:
        # NOTE(review): the clamp only bounds the logits from below; unlike a
        # mask-based torch.where it does not depend on which queries are masked.
        # Presumably intentional for Gaudi - confirm before relying on masking.
        # The float64 round-trip is kept from the copied implementation; the
        # result is always handed back as float32.
        pred_logits = pred_logits.to(torch.float64)
        pred_logits = torch.clamp(pred_logits, min=-1e6)
        pred_logits = pred_logits.to(torch.float32)

    return (pred_logits, image_class_embeds)