# optimum/graphcore/models/lxmert/modeling_lxmert.py
# Copyright 2018 Hao Tan, Mohit Bansal, and the HuggingFace team
# Copyright (c) 2021 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union

import poptorch
import torch
import torch.nn.functional as F
from transformers import LxmertForQuestionAnswering
from transformers.models.lxmert.modeling_lxmert import LxmertForQuestionAnsweringOutput

from optimum.utils import logging

from ...modeling_utils import PipelineMixin, recomputation_checkpoint, register


logger = logging.get_logger(__name__)

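# `register` records this class as the pipelined counterpart of
# `LxmertForQuestionAnswering`, so the optimum-graphcore tooling can swap it
# in when preparing a model for the IPU.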
@register(LxmertForQuestionAnswering)
class PipelinedLxmertForQuestionAnswering(LxmertForQuestionAnswering, PipelineMixin):
def parallelize(self):
"""
Transform the model to run in an IPU pipeline.
- Adds pipeline stages to the model
- Adds recomputation checkpoints
Recommended usage:
```
model = PipelinedLxmertForQuestionAnswering(config).parallelize().half()
```
"""
self._hooks = []
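        # The pipeline spans 4 IPUs: embeddings on IPU 0, language layers on
        # IPU 1, visual layers on IPU 2, and the cross-modality layers, pooler
        # and answer head on IPU 3.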
logger.info("-------------------- Device Allocation --------------------")
logger.info("Embedding --> IPU 0")
self.lxmert.embeddings = poptorch.BeginBlock(self.lxmert.embeddings, "Embedding", ipu_id=0)
logger.info("Image embedding --> IPU 0")
self.lxmert.encoder.visn_fc = poptorch.BeginBlock(self.lxmert.encoder.visn_fc, "Image embedding", ipu_id=0)

        # Language layers
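        # Each layer starts a new pipeline block; with
        # `recompute_checkpoint_every_layer` enabled, every layer boundary is
        # also a recomputation checkpoint, trading recompute for live memory.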
for index, layer in enumerate(self.lxmert.encoder.layer):
if self.ipu_config.recompute_checkpoint_every_layer:
h = recomputation_checkpoint(layer)
self._hooks.append(h)
            self.lxmert.encoder.layer[index] = poptorch.BeginBlock(layer, f"Language layer {index}", ipu_id=1)
            logger.info(f"Language layer {index:<2} --> IPU 1")

        # Visual layers
for index, layer in enumerate(self.lxmert.encoder.r_layers):
if self.ipu_config.recompute_checkpoint_every_layer:
h = recomputation_checkpoint(layer)
self._hooks.append(h)
            self.lxmert.encoder.r_layers[index] = poptorch.BeginBlock(layer, f"Visual layer {index}", ipu_id=2)
            logger.info(f"Visual layer {index:<2} --> IPU 2")

        # Cross modality layers
for index, layer in enumerate(self.lxmert.encoder.x_layers):
if self.ipu_config.recompute_checkpoint_every_layer:
h = recomputation_checkpoint(layer)
self._hooks.append(h)
            self.lxmert.encoder.x_layers[index] = poptorch.BeginBlock(layer, f"Cross modality layer {index}", ipu_id=3)
            logger.info(f"Cross modality layer {index:<2} --> IPU 3")

        logger.info("Pooler --> IPU 3")
self.lxmert.pooler = poptorch.BeginBlock(self.lxmert.pooler, "Pooler", ipu_id=3)
logger.info("Head --> IPU 3")
self.answer_head = poptorch.BeginBlock(self.answer_head, "Head", ipu_id=3)
logger.info("-----------------------------------------------------------")
        return self

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
visual_feats: Optional[torch.FloatTensor] = None,
visual_pos: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
visual_attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[LxmertForQuestionAnsweringOutput, Tuple[torch.FloatTensor]]:
r"""
labels: (`Torch.Tensor` of shape `(batch_size)`, *optional*):
A one-hot representation of the correct answer
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
lxmert_output = self.lxmert(
input_ids=input_ids,
visual_feats=visual_feats,
visual_pos=visual_pos,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
visual_attention_mask=visual_attention_mask,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,
)
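        # `LxmertModelOutput` orders its fields as (language_output,
        # vision_output, pooled_output, ...), so index 2 is the pooled output.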
pooled_output = lxmert_output[2]
answer_score = self.answer_head(pooled_output)
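        # `answer_score` has shape (batch_size, num_qa_labels): one logit per candidate answer.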
        loss = None
        if labels is not None:
            if labels.dim() == 1:
                # Hard labels: `labels` holds one answer index per example
                loss = F.cross_entropy(answer_score.view(-1, self.num_qa_labels), labels.view(-1))
            else:
                # Soft labels (e.g. VQA v2): `labels` holds per-answer target scores
                loss = F.binary_cross_entropy_with_logits(
                    answer_score.view(-1, self.num_qa_labels), labels.view(-1, self.num_qa_labels)
                )
        if not return_dict:
            output = (answer_score,) + lxmert_output[3:]
            return ((loss,) + output) if loss is not None else output

return LxmertForQuestionAnsweringOutput(
loss=loss,
question_answering_score=answer_score,
language_hidden_states=lxmert_output.language_hidden_states,
vision_hidden_states=lxmert_output.vision_hidden_states,
language_attentions=lxmert_output.language_attentions,
vision_attentions=lxmert_output.vision_attentions,
cross_encoder_attentions=lxmert_output.cross_encoder_attentions,
)
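

# A minimal usage sketch (not part of the module above), assuming an
# optimum-graphcore `IPUConfig` instance `ipu_config` is available and that
# `batch` is a dict of tensors matching `forward`'s signature; the names and
# the `to_options(for_inference=True)` call reflect optimum-graphcore's
# IPUConfig API as we understand it, and are illustrative rather than fixed:
#
#     model = PipelinedLxmertForQuestionAnswering.from_pretrained("unc-nlp/lxmert-base-uncased")
#     model.ipu_config = ipu_config
#     model = model.parallelize().half()
#     opts = ipu_config.to_options(for_inference=True)
#     poptorch_model = poptorch.inferenceModel(model.eval(), opts)
#     outputs = poptorch_model(**batch)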