optimum/graphcore/models/deberta/modeling_deberta.py

# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union

import poptorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    DebertaForMaskedLM,
    DebertaForQuestionAnswering,
    DebertaForSequenceClassification,
    DebertaForTokenClassification,
)
from transformers.modeling_outputs import MaskedLMOutput, QuestionAnsweringModelOutput
from transformers.models.deberta.modeling_deberta import (
    DebertaEncoder,
    DisentangledSelfAttention,
    StableDropout,
    build_relative_position,
)

from optimum.utils import logging

from ...modeling_utils import (
    OnehotGather,
    PipelineMixin,
    SerializedEmbedding,
    SerializedLinear,
    get_layer_ipu,
    outline_attribute,
    recomputation_checkpoint,
    register,
)


logger = logging.get_logger(__name__)


class FastGatherLastDim(nn.Module):
    """
    Custom Op that does a faster specialised version of `gather` on the last dimension of a tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, data, idx, target=None):
        if poptorch.isRunningOnIpu():
            if target is None:
                target = torch.zeros_like(idx).to(data.dtype)
            else:
                target = target.type_as(data)
            target.requires_grad_()
            o = poptorch.custom_op(
                [data, idx],
                "FastGatherLastDim",
                "poptorch.custom_ops",
                1,
                example_outputs=[target],
                attributes={"axis": -1},
            )
            return o[0]
        else:
            return torch.gather(data, -1, idx)


gather_last_dim = FastGatherLastDim()
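
# Off the IPU, `gather_last_dim` falls back to a plain `torch.gather` along the last axis.
# A minimal sketch of the equivalence (illustrative only, not part of the model):
#
#     data = torch.arange(12.0).reshape(2, 2, 3)
#     idx = torch.tensor([[[2, 0], [1, 1]], [[0, 2], [2, 0]]])
#     assert torch.equal(gather_last_dim(data, idx), torch.gather(data, -1, idx))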


class XSoftmax(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, input, mask):
        """
        IPU-friendly masked softmax: masked positions are pushed towards zero probability by adding a large
        negative offset to their logits instead of using `masked_fill`.
        """
        rmask = ~(mask.bool())
        output = self.masked_fill_approx(input, rmask, -100000)
        output = torch.softmax(output, self.dim)
        output = self.masked_fill_approx(output, rmask, 0)
        return output

    def masked_fill_approx(self, input, mask, value):
        mask_int = mask.to(torch.int)
        mask_ = value * mask_int
        output = input + mask_
        return output


def _get_rel_embedding(self):
    return self.rel_embeddings.weight + 0.0 if self.relative_attention else None
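
# `masked_fill_approx` adds `value * mask` to the scores rather than calling `masked_fill`, which keeps the
# op IPU-friendly. A minimal sketch of why this approximates masking before the softmax (illustrative only;
# the -100000 offset simply drives masked logits far below the unmasked ones):
#
#     scores = torch.tensor([[1.0, 2.0, 3.0]])
#     mask = torch.tensor([[1, 1, 0]])  # 1 = attend, 0 = masked
#     probs = XSoftmax(-1)(scores, mask)
#     # probs[..., -1] is ~0 and probs[..., :2] match torch.softmax(scores[..., :2], -1) up to float error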


class IPUDisentangledSelfAttention(DisentangledSelfAttention):
    """
    Disentangled self-attention module

    Parameters:
        config (`DebertaConfig`):
            A model config class instance with the configuration to build a new model. The schema is similar to
            *BertConfig*, for more details, please refer to [`DebertaConfig`]
    """

    def __init__(self, config):
        super().__init__(config)
        self.xsoftmax = XSoftmax(-1)

    def forward(
        self,
        hidden_states,
        attention_mask,
        output_attentions=False,
        query_states=None,
        relative_pos=None,
        rel_embeddings=None,
    ):
        """
        Call the module

        Args:
            hidden_states (`torch.FloatTensor`):
                Input states to the module usually the output from previous layer, it will be the Q, K and V in
                *Attention(Q,K,V)*

            attention_mask (`torch.ByteTensor`):
                An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
                sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the
                *j* th token.

            output_attentions (`bool`, optional):
                Whether to return the attention matrix.

            query_states (`torch.FloatTensor`, optional):
                The *Q* state in *Attention(Q,K,V)*.

            relative_pos (`torch.LongTensor`):
                The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*]
                with values ranging in [*-max_relative_positions*, *max_relative_positions*].

            rel_embeddings (`torch.FloatTensor`):
                The embedding of relative distances. It's a tensor of shape
                [\\(2 \\times \\text{max_relative_positions}\\), *hidden_size*].
        """
        if query_states is None:
            qp = self.in_proj(hidden_states)  # .split(self.all_head_size, dim=-1)
            query_layer, key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1)
        else:

            def linear(w, b, x):
                if b is not None:
                    return torch.matmul(x, w.t()) + b.t()
                else:
                    return torch.matmul(x, w.t())  # + b.t()

            ws = self.in_proj.weight.chunk(self.num_attention_heads * 3, dim=0)
            qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)]
            qkvb = [None] * 3

            q = linear(qkvw[0], qkvb[0], query_states.to(dtype=qkvw[0].dtype))
            k, v = [linear(qkvw[i], qkvb[i], hidden_states.to(dtype=qkvw[i].dtype)) for i in range(1, 3)]
            query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]]

        query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
        value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :])

        rel_att = None
        # Take the dot product between "query" and "key" to get the raw attention scores.
        scale_factor = 1 + len(self.pos_att_type)
        scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
        query_layer = query_layer / scale.to(dtype=query_layer.dtype)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        if self.relative_attention:
            rel_embeddings = self.pos_dropout(rel_embeddings)
            rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)

        if rel_att is not None:
            attention_scores = attention_scores + rel_att

        # bxhxlxd
        if self.talking_head:
            attention_scores = self.head_logits_proj(attention_scores.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        attention_probs = self.xsoftmax(attention_scores, attention_mask)
        attention_probs = self.dropout(attention_probs)
        if self.talking_head:
            attention_probs = self.head_weights_proj(attention_probs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (-1,)
        context_layer = context_layer.view(*new_context_layer_shape)
        if output_attentions:
            return (context_layer, attention_probs)
        else:
            return context_layer

    def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
        if relative_pos is None:
            q = query_layer.size(-2)
            relative_pos = build_relative_position(q, key_layer.size(-2), query_layer.device)
        if relative_pos.dim() == 2:
            relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
        elif relative_pos.dim() == 3:
            relative_pos = relative_pos.unsqueeze(1)
        # bxhxqxk
        elif relative_pos.dim() != 4:
            raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")

        att_span = min(max(query_layer.size(-2), key_layer.size(-2)), self.max_relative_positions)
        relative_pos = relative_pos.long().to(query_layer.device)
        rel_embeddings = rel_embeddings[
            self.max_relative_positions - att_span : self.max_relative_positions + att_span, :
        ].unsqueeze(0)

        score = 0

        # content->position
        if "c2p" in self.pos_att_type:
            pos_key_layer = self.pos_proj(rel_embeddings)
            pos_key_layer = self.transpose_for_scores(pos_key_layer)
            c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2))
            c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
            index = c2p_pos.expand(
                [query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]
            )
            c2p_att = gather_last_dim(c2p_att, index)
            score += c2p_att

        # position->content
        if "p2c" in self.pos_att_type:
            pos_query_layer = self.pos_q_proj(rel_embeddings)
            pos_query_layer = self.transpose_for_scores(pos_query_layer)
            pos_query_layer /= torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
            if query_layer.size(-2) != key_layer.size(-2):
                r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2), query_layer.device)
            else:
                r_pos = relative_pos
            p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
            index = p2c_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
            p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2).to(dtype=key_layer.dtype))
            p2c_att = gather_last_dim(p2c_att, index).transpose(-1, -2)
            if query_layer.size(-2) != key_layer.size(-2):
                pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
                # Expand the positional index to [batch, heads, query_len, key_len] before gathering along the
                # key dimension.
                index = pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
                p2c_att = gather_last_dim(p2c_att, index)
            score += p2c_att

        return score
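
# The disentangled bias indexes the relative-position embedding table with clamped relative distances.
# A minimal sketch of the index construction (illustrative only):
#
#     rel_pos = build_relative_position(4, 4, torch.device("cpu"))    # shape [1, 4, 4], rel_pos[0, i, j] = i - j
#     att_span = 4
#     c2p_pos = torch.clamp(rel_pos + att_span, 0, att_span * 2 - 1)  # values stay within [0, 2 * att_span - 1]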


class DebertaPipelineMixin(PipelineMixin):
    def change_modules_for_ipu(self, restore: bool):
        for mod in self.modules():
            if isinstance(mod, DisentangledSelfAttention):
                mod.__class__ = DisentangledSelfAttention if restore else IPUDisentangledSelfAttention
                if restore:
                    del mod.xsoftmax
                else:
                    mod.xsoftmax = XSoftmax(-1)
            if restore:
                if isinstance(mod, nn.Dropout):
                    mod.__class__ = StableDropout
                    mod.drop_prob = mod.p
                    mod.count = 0
                    mod.context_stack = None
            else:
                if isinstance(mod, StableDropout):
                    mod.__class__ = nn.Dropout
                    mod.p = mod.drop_prob
                    mod.inplace = False
            if isinstance(mod, DebertaEncoder):
                func = DebertaEncoder.get_rel_embedding if restore else _get_rel_embedding
                mod.get_rel_embedding = func.__get__(mod, DebertaEncoder)

    def parallelize(self):
        """
        Transform the model to run in an IPU pipeline.

        - Adds pipeline stages to the model
        - (If enabled) Replaces the word embedding with a SerializedEmbedding
        - Replaces several modules with IPU compatible counterparts
        - Adds recomputation checkpoints
        """
        self._hooks = []
        logger.info("-------------------- Device Allocation --------------------")
        logger.info("Embedding --> IPU 0")

        if self.ipu_config.embedding_serialization_factor > 1:
            if isinstance(self, PipelinedDebertaForMaskedLM):
                self.cls.predictions.decoder = SerializedLinear.from_model(
                    self.cls.predictions.decoder, self.ipu_config.embedding_serialization_factor
                )
                self.tie_weights()
            else:
                self.deberta.embeddings.word_embeddings = SerializedEmbedding.from_model(
                    self.deberta.embeddings.word_embeddings, self.ipu_config.embedding_serialization_factor
                )
        self.change_modules_for_ipu(False)

        self.deberta.embeddings = poptorch.BeginBlock(self.deberta.embeddings, "Embedding", ipu_id=0)
        hs = outline_attribute(self.deberta.embeddings.LayerNorm, "embedding")
        self._hooks.extend(hs)

        self.deberta.encoder = poptorch.BeginBlock(self.deberta.encoder, ipu_id=0)
        if self.deberta.encoder.relative_attention:
            self.deberta.encoder.rel_embeddings = poptorch.BeginBlock(self.deberta.encoder.rel_embeddings, ipu_id=0)

        layer_ipu = get_layer_ipu(self.ipu_config, self.deberta.encoder.layer)
        for index, layer in enumerate(self.deberta.encoder.layer):
            ipu = layer_ipu[index]
            if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1:
                h = recomputation_checkpoint(layer)
                self._hooks.append(h)
            self.deberta.encoder.layer[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
            logger.info(f"Encoder {index:<2} --> IPU {ipu}")

        if isinstance(self, PipelinedDebertaForMaskedLM):
            logger.info(f"Projection {index:<2} --> IPU {0}")
            self.cls.predictions.decoder = poptorch.BeginBlock(self.cls.predictions.decoder, "Projection", ipu_id=0)

        return self

    def deparallelize(self):
        """
        Undo the changes to the model done by `parallelize`.

        You should call this before doing `save_pretrained` so that the `model.state_dict` is compatible with the
        original model.
        """
        super().deparallelize()
        self.change_modules_for_ipu(True)
        if self.ipu_config.embedding_serialization_factor > 1:
            if isinstance(self.cls.predictions.decoder, SerializedLinear):
                self.cls.predictions.decoder = self.cls.predictions.decoder.to_model()
                self.tie_weights()
            else:
                # Deserialize the serialized word embedding
                self.deberta.embeddings.word_embeddings = self.deberta.embeddings.word_embeddings.to_model()
        return self
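
# A rough sketch of the lifecycle these mixin methods are designed around, based on the recommended usage in
# the Pipelined* class docstrings below and on the `deparallelize` docstring (illustrative only; building
# `config`, attaching the IPU configuration and wrapping with PopTorch are assumed to be handled by the
# surrounding training script, e.g. via IPUTrainer):
#
#     model = PipelinedDebertaForSequenceClassification(config).parallelize().half()
#     ...  # wrap with poptorch and train
#     model.deparallelize()                        # undo the IPU-specific changes
#     model.save_pretrained("path/to/checkpoint")  # state_dict is now compatible with the original model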


@register(DebertaForMaskedLM)
class PipelinedDebertaForMaskedLM(DebertaForMaskedLM, DebertaPipelineMixin):
    """
    DebertaForMaskedLM transformed to run in an IPU pipeline.

    Recommended usage:
    ```
    model = PipelinedDebertaForMaskedLM(config).parallelize().half()
    ```
    """

    def __init__(self, config):
        super().__init__(config)
        self.gather_indices = OnehotGather()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked),
            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.deberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        if labels is not None:
            # Select only the masked tokens for the classifier
            max_number_of_masked_tokens = int(labels.size(1) * 0.25)
            masked_lm_labels, masked_lm_positions = torch.topk(labels, k=max_number_of_masked_tokens, dim=1)
            masked_output = self.gather_indices(sequence_output, masked_lm_positions)
        else:
            # This case should never happen during training
            masked_output = sequence_output

        prediction_scores = self.cls(masked_output)

        masked_lm_loss = None
        if labels is not None:
            masked_lm_loss = F.cross_entropy(
                prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)
            ).float()

        if not return_dict:
            output = (prediction_scores,) + outputs[1:]
            return ((masked_lm_loss,)) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores if masked_lm_loss is None else None,
            hidden_states=outputs.hidden_states if masked_lm_loss is None else None,
            attentions=outputs.attentions if masked_lm_loss is None else None,
        )
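
# The masked-token selection above assumes unmasked positions carry the label -100, so the k largest entries
# of `labels` are the genuinely masked tokens (any -100 values that slip through are ignored by
# `F.cross_entropy`, whose default `ignore_index` is -100). A small illustration (illustrative only):
#
#     labels = torch.tensor([[-100, 532, -100, -100, 76, -100, -100, -100]])
#     values, positions = torch.topk(labels, k=2, dim=1)
#     # values == tensor([[532, 76]]), positions == tensor([[1, 4]])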


@register(DebertaForSequenceClassification)
class PipelinedDebertaForSequenceClassification(DebertaForSequenceClassification, DebertaPipelineMixin):
    """
    DebertaForSequenceClassification transformed to run in an IPU pipeline.

    Recommended usage:
    ```
    model = PipelinedDebertaForSequenceClassification(config).parallelize().half()
    ```
    """

    def parallelize(self):
        super().parallelize()
        last_ipu = self.ipu_config._ipus_per_replica - 1
        logger.info(f"Classifier Output --> IPU {last_ipu}")
        self.classifier = poptorch.BeginBlock(self.classifier, "Classifier Output", ipu_id=last_ipu)
        logger.info("-----------------------------------------------------------")
        return self


@register(DebertaForTokenClassification)
class PipelinedDebertaForTokenClassification(DebertaForTokenClassification, DebertaPipelineMixin):
    """
    DebertaForTokenClassification transformed to run in an IPU pipeline.

    Recommended usage:
    ```
    model = PipelinedDebertaForTokenClassification(config).parallelize().half()
    ```
    """

    def parallelize(self):
        super().parallelize()
        last_ipu = self.ipu_config._ipus_per_replica - 1
        logger.info(f"Classifier Output --> IPU {last_ipu}")
        self.classifier = poptorch.BeginBlock(self.classifier, "Classifier Output", ipu_id=last_ipu)
        logger.info("-----------------------------------------------------------")
        return self

    def deparallelize(self):
        super().deparallelize()
        # The last dropout isn't a StableDropout, so undo its replacement
        # made by change_modules_for_ipu
        mod = self.dropout
        if isinstance(mod, StableDropout):
            mod.__class__ = nn.Dropout
            mod.p = mod.drop_prob
            mod.inplace = False


@register(DebertaForQuestionAnswering)
class PipelinedDebertaForQuestionAnswering(DebertaForQuestionAnswering, DebertaPipelineMixin):
    """
    DebertaForQuestionAnswering transformed to run in an IPU pipeline.

    Recommended usage:
    ```
    model = PipelinedDebertaForQuestionAnswering(config).parallelize().half()
    ```
    """

    def parallelize(self):
        super().parallelize()
        last_ipu = self.ipu_config._ipus_per_replica - 1
        logger.info(f"QA Outputs --> IPU {last_ipu}")
        self.qa_outputs = poptorch.BeginBlock(self.qa_outputs, "QA Outputs", ipu_id=last_ipu)
        logger.info("-----------------------------------------------------------")
        return self

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # return_dict = False
        output = super().forward(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            start_positions=start_positions,
            end_positions=end_positions,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if start_positions is not None and end_positions is not None:
            output = (poptorch.identity_loss(output[0], reduction="none"),) + output[1:]
        return output
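
# When both `start_positions` and `end_positions` are supplied, the loss returned by the parent forward is
# wrapped in `poptorch.identity_loss` so PopTorch knows which output to backpropagate. A rough training-time
# sketch (illustrative only; `opts`, `optimizer` and the input tensors are assumed to come from the
# surrounding script, typically via IPUTrainer):
#
#     qa_model = PipelinedDebertaForQuestionAnswering(config).parallelize().half()
#     poptorch_model = poptorch.trainingModel(qa_model, options=opts, optimizer=optimizer)
#     outputs = poptorch_model(input_ids=input_ids, attention_mask=attention_mask,
#                              start_positions=start_positions, end_positions=end_positions)
#     # outputs[0] is the combined start/end cross-entropy loss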