# optimum/habana/transformers/models/blip/modeling_blip.py
from typing import Optional

import torch
from transformers.utils import logging


logger = logging.get_logger(__name__)
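
# The functions below are drop-in replacements for the `generate` methods of the
# Transformers BLIP models. A minimal sketch of the wiring (assumption: the usual
# optimum-habana monkey-patching pattern applied in `adapt_transformers_to_gaudi`;
# the actual call site lives elsewhere in the library):
#
#     from transformers.models.blip.modeling_blip import (
#         BlipForConditionalGeneration,
#         BlipForQuestionAnswering,
#     )
#
#     BlipForConditionalGeneration.generate = gaudi_BlipForConditionalGeneration_generate
#     BlipForQuestionAnswering.generate = gaudi_BlipForQuestionAnswering_generate
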
@torch.no_grad()
def gaudi_BlipForConditionalGeneration_generate(
    self,
    pixel_values: torch.FloatTensor,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    interpolate_pos_encoding: bool = False,
    **generate_kwargs,
) -> torch.LongTensor:
"""
Copied from BlipForQuestionAnswering.generate: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/blip/modeling_blip.py#L1022
The only differences are:
- wrap hpu graph for each part
"""
    if generate_kwargs.get("hpu_graphs", True):
        from habana_frameworks.torch.hpu import wrap_in_hpu_graph

        # `wrap_in_hpu_graph` adds a `clear_cache` method to the wrapped module, so its
        # absence indicates the module has not been wrapped yet.
        if not hasattr(self.vision_model, "clear_cache"):
            self.vision_model = wrap_in_hpu_graph(self.vision_model)
        if not hasattr(self.text_decoder, "clear_cache"):
            self.text_decoder = wrap_in_hpu_graph(self.text_decoder)
    batch_size = pixel_values.shape[0]
    vision_outputs = self.vision_model(
        pixel_values=pixel_values,
        interpolate_pos_encoding=interpolate_pos_encoding,
    )

    image_embeds = vision_outputs[0]
    image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
    if isinstance(input_ids, list):
        input_ids = torch.LongTensor(input_ids)
    elif input_ids is None:
        # No prompt provided: start decoding from the default decoder input ids.
        input_ids = (
            torch.LongTensor([[self.decoder_input_ids, self.config.text_config.eos_token_id]])
            .repeat(batch_size, 1)
            .to(image_embeds.device)
        )

    # Force BOS as the first token; the trailing EOS is dropped before decoding below.
    input_ids[:, 0] = self.config.text_config.bos_token_id
    attention_mask = attention_mask[:, :-1] if attention_mask is not None else None
    outputs = self.text_decoder.generate(
        input_ids=input_ids[:, :-1],
        eos_token_id=self.config.text_config.sep_token_id,
        pad_token_id=self.config.text_config.pad_token_id,
        attention_mask=attention_mask,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_attention_mask,
        **generate_kwargs,
    )

    return outputs
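

# Illustrative usage for image captioning once the patched `generate` above is in
# place (a sketch, not part of the original module; the model id and device
# handling are assumptions):
#
#     import requests
#     from PIL import Image
#     from transformers import BlipForConditionalGeneration, BlipProcessor
#     from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
#
#     adapt_transformers_to_gaudi()
#     processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
#     model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to("hpu")
#
#     image = Image.open(requests.get("https://example.com/demo.jpg", stream=True).raw)  # placeholder URL
#     inputs = processor(images=image, return_tensors="pt").to("hpu")
#     generated_ids = model.generate(**inputs, max_new_tokens=20)
#     print(processor.decode(generated_ids[0], skip_special_tokens=True))

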
@torch.no_grad()
def gaudi_BlipForQuestionAnswering_generate(
    self,
    input_ids: torch.LongTensor,
    pixel_values: torch.FloatTensor,
    attention_mask: Optional[torch.LongTensor] = None,
    interpolate_pos_encoding: bool = False,
    **generate_kwargs,
) -> torch.LongTensor:
"""
Copied from BlipForQuestionAnswering.generate: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/blip/modeling_blip.py#L1236
The only differences are:
- wrap hpu graph for each part
- torch.full add dtype=torch.int64, or else the default type is torch.float32. lead to coredump in embeding layer
"""
    if generate_kwargs.get("hpu_graphs", True):
        from habana_frameworks.torch.hpu import wrap_in_hpu_graph

        # `wrap_in_hpu_graph` adds a `clear_cache` method to the wrapped module, so its
        # absence indicates the module has not been wrapped yet.
        if not hasattr(self.vision_model, "clear_cache"):
            self.vision_model = wrap_in_hpu_graph(self.vision_model)
        if not hasattr(self.text_encoder, "clear_cache"):
            self.text_encoder = wrap_in_hpu_graph(self.text_encoder)
        if not hasattr(self.text_decoder, "clear_cache"):
            self.text_decoder = wrap_in_hpu_graph(self.text_decoder)
    vision_outputs = self.vision_model(
        pixel_values=pixel_values,
        interpolate_pos_encoding=interpolate_pos_encoding,
    )

    image_embeds = vision_outputs[0]
    image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

    if isinstance(input_ids, list):
        input_ids = torch.LongTensor(input_ids)
    question_outputs = self.text_encoder(
        input_ids=input_ids,
        attention_mask=attention_mask,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_attention_mask,
        return_dict=False,
    )

    question_embeds = question_outputs[0]
    question_attention_mask = torch.ones(question_embeds.size()[:-1], dtype=torch.long, device=question_embeds.device)

    # Explicit dtype: without it `torch.full` creates a float32 tensor, which
    # crashes the embedding layer on HPU.
    bos_ids = torch.full(
        (question_embeds.size(0), 1),
        fill_value=self.decoder_start_token_id,
        device=question_embeds.device,
        dtype=torch.int64,
    )
    outputs = self.text_decoder.generate(
        input_ids=bos_ids,
        eos_token_id=self.config.text_config.sep_token_id,
        pad_token_id=self.config.text_config.pad_token_id,
        encoder_hidden_states=question_embeds,
        encoder_attention_mask=question_attention_mask,
        **generate_kwargs,
    )

    return outputs
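

# Illustrative usage for visual question answering with the patched `generate`
# above (a sketch, not part of the original module; the model id and device
# handling are assumptions):
#
#     import requests
#     from PIL import Image
#     from transformers import BlipForQuestionAnswering, BlipProcessor
#     from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
#
#     adapt_transformers_to_gaudi()
#     processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
#     model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to("hpu")
#
#     image = Image.open(requests.get("https://example.com/demo.jpg", stream=True).raw)  # placeholder URL
#     inputs = processor(images=image, text="how many dogs are in the picture?", return_tensors="pt").to("hpu")
#     answer_ids = model.generate(**inputs)
#     print(processor.decode(answer_ids[0], skip_special_tokens=True))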