optimum/exporters/neuron/model_wrappers.py (567 lines of code) (raw):

# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model wrappers for Neuron export."""

from typing import TYPE_CHECKING, List, Optional

import torch
from transformers.models.t5.modeling_t5 import T5LayerCrossAttention

from ...neuron.utils import is_neuronx_available, is_neuronx_distributed_available


if is_neuronx_available():
    import torch_xla.core.xla_model as xm

if is_neuronx_distributed_available():
    import neuronx_distributed

if TYPE_CHECKING:
    from transformers.modeling_utils import PreTrainedModel


def _bind_inputs(input_names: List[str], inputs: tuple) -> dict:
    """Pair positional `inputs` with `input_names` and return them as a dict.

    Args:
        input_names: Names of the expected inputs, in positional order.
        inputs: Positional inputs received by a wrapper's `forward`.

    Returns:
        A dict mapping each name to the input at the same position.

    Raises:
        ValueError: If the number of inputs differs from the number of names.
    """
    if len(inputs) != len(input_names):
        raise ValueError(
            f"The model needs {len(input_names)} inputs: {input_names}."
            f" But only {len(inputs)} inputs are passed."
        )
    return dict(zip(input_names, inputs))


class UnetNeuronWrapper(torch.nn.Module):
    """Wrapper exposing a UNet through a purely positional `forward`.

    Positional inputs are matched to `input_names` and re-packed into the keyword arguments
    passed to the wrapped model.
    """

    def __init__(self, model, input_names: List[str], device: Optional[str] = None):
        super().__init__()
        self.model = model
        self.input_names = input_names
        self.device = device

    def forward(self, *inputs):
        """Call the wrapped model with `inputs` given in the order of `self.input_names`.

        Returns:
            The wrapped model's output, requested with `return_dict=False`.

        Raises:
            ValueError: If the number of inputs does not match `self.input_names`.
        """
        ordered_inputs = _bind_inputs(self.input_names, inputs)

        # Explicit `is None` check: `tensor or fallback` would call `bool()` on the tensor,
        # which raises for tensors holding more than one element.
        image_embeds = ordered_inputs.pop("image_embeds", None)
        if image_embeds is None:
            image_embeds = ordered_inputs.pop("image_enc_hidden_states", None)

        added_cond_kwargs = {
            "text_embeds": ordered_inputs.pop("text_embeds", None),
            "time_ids": ordered_inputs.pop("time_ids", None),
            "image_embeds": image_embeds,
        }
        sample = ordered_inputs.pop("sample", None)
        # `timestep` is mandatory; it is expanded to one value per sample.
        timestep = ordered_inputs.pop("timestep").float().expand((sample.shape[0],))
        encoder_hidden_states = ordered_inputs.pop("encoder_hidden_states", None)

        # Re-build down_block_additional_residual
        down_block_additional_residuals_names = [
            name for name in ordered_inputs.keys() if "down_block_additional_residuals" in name
        ]
        down_block_additional_residuals = tuple(
            ordered_inputs.pop(name) for name in down_block_additional_residuals_names
        )
        mid_block_additional_residual = ordered_inputs.pop("mid_block_additional_residual", None)

        out_tuple = self.model(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            down_block_additional_residuals=(
                down_block_additional_residuals if down_block_additional_residuals else None
            ),
            mid_block_additional_residual=mid_block_additional_residual,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )

        return out_tuple


class PixartTransformerNeuronWrapper(torch.nn.Module):
    """Wrapper exposing a PixArt transformer through a purely positional `forward`."""

    def __init__(self, model, input_names: List[str], device: Optional[str] = None):
        super().__init__()
        self.model = model
        self.dtype = model.dtype
        self.input_names = input_names
        self.device = device

    def forward(self, *inputs):
        """Call the wrapped model with `inputs` given in the order of `self.input_names`.

        Returns:
            The wrapped model's output, requested with `return_dict=False`.

        Raises:
            ValueError: If the number of inputs does not match `self.input_names`.
        """
        ordered_inputs = _bind_inputs(self.input_names, inputs)

        sample = ordered_inputs.pop("sample", None)
        encoder_hidden_states = ordered_inputs.pop("encoder_hidden_states", None)
        timestep = ordered_inputs.pop("timestep", None)
        encoder_attention_mask = ordered_inputs.pop("encoder_attention_mask", None)

        # Additional conditions
        out_tuple = self.model(
            hidden_states=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            added_cond_kwargs={"resolution": None, "aspect_ratio": None},
            return_dict=False,
        )

        return out_tuple


class ControlNetNeuronWrapper(torch.nn.Module):
    """Wrapper exposing a ControlNet through a purely positional `forward`."""

    def __init__(self, model, input_names: List[str], device: Optional[str] = None):
        super().__init__()
        self.model = model
        self.input_names = input_names
        self.device = device

    def forward(self, *inputs):
        """Call the wrapped model with `inputs` given in the order of `self.input_names`.

        Inputs that are not consumed explicitly are forwarded to the model as extra keyword arguments.

        Returns:
            The wrapped model's output, requested with `return_dict=False`.

        Raises:
            ValueError: If the number of inputs does not match `self.input_names`.
        """
        ordered_inputs = _bind_inputs(self.input_names, inputs)

        sample = ordered_inputs.pop("sample", None)
        timestep = ordered_inputs.pop("timestep", None)
        encoder_hidden_states = ordered_inputs.pop("encoder_hidden_states", None)
        controlnet_cond = ordered_inputs.pop("controlnet_cond", None)
        conditioning_scale = ordered_inputs.pop("conditioning_scale", None)

        # Additional conditions for the Stable Diffusion XL UNet.
        added_cond_kwargs = {
            "text_embeds": ordered_inputs.pop("text_embeds", None),
            "time_ids": ordered_inputs.pop("time_ids", None),
        }

        out_tuple = self.model(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=controlnet_cond,
            conditioning_scale=conditioning_scale,
            added_cond_kwargs=added_cond_kwargs,
            guess_mode=False,  # TODO: support guess mode of ControlNet
            return_dict=False,
            **ordered_inputs,
        )

        return out_tuple


# Adapted from https://github.com/aws-neuron/aws-neuron-samples/blob/master/torch-neuronx/inference/hf_pretrained_pixart_alpha_inference_on_inf2.ipynb
# For text encoding
class T5EncoderWrapper(torch.nn.Module):
    """Wrapper around a T5 encoder model used for text encoding.

    On construction the wrapped model is patched in place: the feed-forward activation of every
    encoder block is replaced by a tanh-approximated GELU, and the self-attention position bias
    of the first block is computed once for `sequence_length` and then returned for every call.
    """

    def __init__(
        self, model: "PreTrainedModel", sequence_length: int, batch_size: Optional[int] = None, device: str = "cpu"
    ):
        super().__init__()
        self.model = model
        self.config = model.config
        self.sequence_length = sequence_length
        self.batch_size = batch_size
        self.device = device

        for block in self.model.encoder.block:
            block.layer[1].DenseReluDense.act = torch.nn.GELU(approximate="tanh")

        first_self_attention = self.model.encoder.block[0].layer[0].SelfAttention
        precomputed_bias = first_self_attention.compute_bias(self.sequence_length, self.sequence_length)
        # Freeze the bias: whatever arguments are passed later, the precomputed tensor is returned.
        first_self_attention.compute_bias = lambda *args, **kwargs: precomputed_bias

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model on `input_ids` with the given `attention_mask`."""
        return self.model(input_ids, attention_mask=attention_mask)


# Adapted from https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/torch-neuronx/t5-inference-tutorial.html
# For text encoding + KV cache initialization
class T5EncoderForSeq2SeqLMWrapper(torch.nn.Module):
    """Wrapper to trace the encoder and the kv cache initialization in the decoder."""

    def __init__(
        self,
        model: "PreTrainedModel",
        sequence_length: Optional[int] = None,
        batch_size: Optional[int] = None,
        num_beams: int = 1,
        device: str = "xla",
        tensor_parallel_size: int = 1,
    ):
        super().__init__()
        self.model = model
        self.config = model.config
        self.num_beams = num_beams
        self.sequence_length = sequence_length
        self.batch_size = batch_size
        self.device = device
        self.tensor_parallel_size = tensor_parallel_size
        self.num_attention_heads_per_partition = self.config.num_heads  # when tensor_parallel_size=1
        if self.tensor_parallel_size > 1:
            self.num_attention_heads_per_partition = (
                self.num_attention_heads_per_partition
                // neuronx_distributed.parallel_layers.parallel_state.get_tensor_model_parallel_size()
            )

        # NOTE(review): the caches below are built unconditionally, so `sequence_length` and
        # `batch_size` must be integers here even though the signature defaults them to None.
        # Two tensors (key, value) per decoder layer for both self- and cross-attention.
        self.past_key_values_sa = torch.nn.ParameterList(
            [
                torch.nn.Parameter(
                    torch.ones(
                        (
                            self.num_beams * batch_size,
                            self.num_attention_heads_per_partition,
                            self.sequence_length - 1,
                            self.config.d_kv,
                        ),
                        dtype=torch.float32,
                    ),
                    requires_grad=False,
                )
                for _ in range(self.config.num_decoder_layers * 2)
            ]
        )
        self.past_key_values_ca = torch.nn.ParameterList(
            [
                torch.nn.Parameter(
                    torch.ones(
                        (
                            self.num_beams * batch_size,
                            self.num_attention_heads_per_partition,
                            self.sequence_length,
                            self.config.d_kv,
                        ),
                        dtype=torch.float32,
                    ),
                    requires_grad=False,
                )
                for _ in range(self.config.num_decoder_layers * 2)
            ]
        )

    def forward(self, input_ids, attention_mask):
        """Encode `input_ids` and build the initial decoder kv cache.

        Returns:
            A flat list: the self-attention key/value states of every decoder layer (all zeros),
            followed by the cross-attention key/value states computed from the encoder output.
        """
        # Infer shapes of dummy inputs used for tracing
        batch_size = input_ids.shape[0]
        sequence_length = input_ids.shape[1]

        # Compare against the input shapes; checking the attributes for truthiness alone would
        # never detect the mismatch described in the messages.
        if self.sequence_length is not None:
            assert self.sequence_length == sequence_length, (
                f"Different sequence length for the parallel partition({self.sequence_length}) and for dummy inputs({sequence_length}). Make sure that they have the same value."
            )
        if self.batch_size is not None:
            assert self.batch_size == batch_size, (
                f"Different batch size for the parallel partition({self.batch_size}) and for dummy inputs({batch_size}). Make sure that they have the same value."
            )

        encoder_output = self.model.encoder(
            input_ids=input_ids, attention_mask=attention_mask, output_attentions=False, output_hidden_states=False
        )

        last_hidden_state = encoder_output["last_hidden_state"]
        # Repeat every batch entry `num_beams` times along the first dimension.
        encoder_hidden_states = torch.concat(
            [tensor.unsqueeze(0).repeat(self.num_beams, 1, 1) for tensor in last_hidden_state]
        )

        decoder_blocks = self.model.decoder.block
        present_key_value_states_sa = []
        present_key_value_states_ca = []

        for i, block in enumerate(decoder_blocks):
            # Cross attention has to be initialized with the encoder hidden state
            cross_attention: T5LayerCrossAttention = block.layer[1]
            attention = cross_attention.EncDecAttention

            def shape(states):
                """projection"""
                return states.view(
                    self.num_beams * batch_size,
                    -1,
                    self.num_attention_heads_per_partition,
                    attention.key_value_proj_dim,
                ).transpose(1, 2)

            key_states = shape(attention.k(encoder_hidden_states))
            value_states = shape(attention.v(encoder_hidden_states))

            if not self.tensor_parallel_size > 1:
                # cross_attn_kv_state
                present_key_value_states_ca.append(key_states)
                present_key_value_states_ca.append(value_states)

                # Self attention kv states are initialized to zeros. This is done to keep the size of the kv cache tensor constant.
                # The kv cache is padded here to keep a fixed shape.
                # [key states]
                present_key_value_states_sa.append(
                    torch.zeros(
                        (self.num_beams * batch_size, self.config.num_heads, sequence_length - 1, self.config.d_kv),
                        dtype=torch.float32,
                        device=self.device,
                    )
                )
                # [value states]
                present_key_value_states_sa.append(
                    torch.zeros(
                        (self.num_beams * batch_size, self.config.num_heads, sequence_length - 1, self.config.d_kv),
                        dtype=torch.float32,
                        device=self.device,
                    )
                )
            else:
                # The registered cache parameters are multiplied by zero so that they take part in
                # the computation without changing the resulting values.
                present_key_value_states_ca.append((self.past_key_values_ca[i * 2] * 0) + key_states)
                present_key_value_states_ca.append((self.past_key_values_ca[i * 2 + 1] * 0) + value_states)
                present_key_value_states_sa.append(
                    self.past_key_values_sa[i * 2]
                    * torch.zeros(
                        (
                            self.num_beams * self.batch_size,
                            self.num_attention_heads_per_partition,
                            self.sequence_length - 1,
                            self.config.d_kv,
                        ),
                        dtype=torch.float32,
                        device=self.device,
                    )
                )
                present_key_value_states_sa.append(
                    self.past_key_values_sa[i * 2 + 1]
                    * torch.zeros(
                        (
                            self.num_beams * self.batch_size,
                            self.num_attention_heads_per_partition,
                            self.sequence_length - 1,
                            self.config.d_kv,
                        ),
                        dtype=torch.float32,
                        device=self.device,
                    )
                )

        return present_key_value_states_sa + present_key_value_states_ca


# Adapted from https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/torch-neuronx/t5-inference-tutorial.html
class T5DecoderWrapper(torch.nn.Module):
    """Wrapper to trace the decoder with past keys values with a language head."""

    def __init__(
        self,
        model: "PreTrainedModel",
        batch_size: int,
        sequence_length: int,
        num_beams: int = 1,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
        device: str = "xla",
        tensor_parallel_size: int = 1,
    ):
        super().__init__()
        self.model = model
        self.config = model.config
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.num_beams = num_beams
        self.output_hidden_states = output_hidden_states
        self.output_attentions = output_attentions
        self.device = device
        self.tensor_parallel_size = tensor_parallel_size

        self.num_attention_heads_per_partition = self.config.num_heads
        if tensor_parallel_size > 1:
            self.num_attention_heads_per_partition = (
                self.num_attention_heads_per_partition
                // neuronx_distributed.parallel_layers.parallel_state.get_tensor_model_parallel_size()
            )

        # Initialize KV cache (num_beams, n_heads, seq_length, dim_per_head)
        # NOTE(review): no cache is created when `device` is neither "cpu" nor "xla"; `forward`
        # would then fail on the missing attributes.
        if device == "cpu":
            self.past_key_values_sa = [
                torch.ones(
                    (num_beams, self.config.num_heads, self.sequence_length - 1, self.config.d_kv), dtype=torch.float32
                )
                for _ in range(self.config.num_decoder_layers * 2)
            ]
            self.past_key_values_ca = [
                torch.ones(
                    (num_beams, self.config.num_heads, self.sequence_length, self.config.d_kv), dtype=torch.float32
                )
                for _ in range(self.config.num_decoder_layers * 2)
            ]
        elif device == "xla":
            self.past_key_values_sa = torch.nn.ParameterList(
                [
                    torch.nn.Parameter(
                        torch.ones(
                            (
                                self.batch_size * self.num_beams,
                                self.num_attention_heads_per_partition,
                                sequence_length - 1,
                                self.config.d_kv,
                            ),
                            dtype=torch.float32,
                        ),
                        requires_grad=False,
                    )
                    for _ in range(self.config.num_decoder_layers * 2)
                ]
            )
            self.past_key_values_ca = torch.nn.ParameterList(
                [
                    torch.nn.Parameter(
                        torch.ones(
                            (
                                self.batch_size * self.num_beams,
                                self.num_attention_heads_per_partition,
                                sequence_length,
                                self.config.d_kv,
                            ),
                            dtype=torch.float32,
                        ),
                        requires_grad=False,
                    )
                    for _ in range(self.config.num_decoder_layers * 2)
                ]
            )

    def update_past(self, past_key_values):
        """Split the per-layer cache into self-attention and cross-attention parts.

        The first two entries of each layer lose their first position along dimension 2; the
        remaining entries are passed through unchanged.

        Returns:
            A pair `(new_past_sa, new_past_ca)` of lists holding one list per layer.
        """
        new_past_sa = []
        new_past_ca = []
        for past_layer in past_key_values:
            layer_states = list(past_layer)
            new_past_sa.append([state[:, :, 1:] for state in layer_states[:2]])
            new_past_ca.append(layer_states[2:])
        return new_past_sa, new_past_ca

    def reorder_cache(self, past_key_values, beam_idx):
        """Gather every cache entry along dimension 0 with `beam_idx`, in place, and return the cache."""
        for i in range(len(past_key_values)):
            gather_index = beam_idx.view([beam_idx.shape[0], 1, 1, 1]).expand_as(past_key_values[i])
            past_key_values[i] = torch.gather(past_key_values[i], dim=0, index=gather_index)
        return past_key_values

    def forward(
        self,
        input_ids,
        decoder_attention_mask,
        encoder_hidden_states,
        encoder_attention_mask,
        beam_idx,
        beam_scores,
        **kwargs,
    ):
        """Run one decoding step and select the next tokens inside the wrapper.

        Returns:
            A flat list. With `num_beams > 1` it starts with
            `[next_token_scores, next_tokens, next_indices]`, otherwise with `[next_tokens]`.
            The flattened self-attention cache and then the cross-attention cache follow, and
            finally the hidden states and attentions when the corresponding flags are set.
        """
        if self.num_beams > 1:
            # We reorder the cache based on the beams selected in each iteration. Required step for beam search.
            past_key_values_sa = self.reorder_cache(self.past_key_values_sa, beam_idx)
            past_key_values_ca = self.reorder_cache(self.past_key_values_ca, beam_idx)
        else:
            # We do not need to reorder for greedy sampling
            past_key_values_sa = self.past_key_values_sa
            past_key_values_ca = self.past_key_values_ca

        # The cache is stored in a flatten form. We order the cache per layer before passing it to the decoder.
        # Each layer has 4 tensors, so we group by 4.
        past_key_values = [
            [*past_key_values_sa[i * 2 : i * 2 + 2], *past_key_values_ca[i * 2 : i * 2 + 2]]
            for i in range(len(past_key_values_ca) // 2)
        ]

        decoder_output = self.model.decoder(
            input_ids=input_ids,
            attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
            output_attentions=self.output_attentions,
            output_hidden_states=self.output_hidden_states,
        )

        last_hidden_state = decoder_output["last_hidden_state"]
        past_key_values = decoder_output["past_key_values"]
        if self.output_hidden_states:
            decoder_hidden_states = list(
                decoder_output["hidden_states"]
            )  # flatten `hidden_states` which is a tuple of tensors

        if self.output_attentions:
            decoder_attentions = list(
                decoder_output["attentions"]
            )  # flatten `decoder_attentions` which is a tuple of tensors
            cross_attentions = list(
                decoder_output["cross_attentions"]
            )  # flatten `cross_attentions` which is a tuple of tensors

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            last_hidden_state = last_hidden_state * (self.model.config.d_model**-0.5)

        lm_logits = self.model.lm_head(last_hidden_state)

        past_key_values_sa, past_key_values_ca = self.update_past(past_key_values)

        # We flatten the cache to a single array. This is required for the input output aliasing to work
        past_key_values_sa = [vec for kv_per_layer in past_key_values_sa for vec in kv_per_layer]
        past_key_values_ca = [vec for kv_per_layer in past_key_values_ca for vec in kv_per_layer]

        if self.device == "cpu":
            self.past_key_values_sa = past_key_values_sa
            self.past_key_values_ca = past_key_values_ca

        # We calculate topk inside the wrapper
        next_token_logits = lm_logits[:, -1, :]

        if self.num_beams > 1:
            # This section of beam search is run outside the decoder in the huggingface t5 implementation.
            # To maximize the computation within the neuron device, we move this within the wrapper
            # Log-softmax written out by hand, with the maximum subtracted before exponentiating.
            logit_max, _ = torch.max(next_token_logits, dim=-1, keepdim=True)
            logsumexp = torch.log(torch.exp(next_token_logits - logit_max).sum(dim=-1, keepdim=True))
            next_token_scores = next_token_logits - logit_max - logsumexp
            next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(self.batch_size, self.num_beams * vocab_size)
            next_token_scores = next_token_scores * 1

            # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * self.num_beams, dim=1, largest=True, sorted=True
            )

            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size

            neuron_outputs = [next_token_scores, next_tokens, next_indices] + past_key_values_sa + past_key_values_ca

        else:
            # Greedy
            next_tokens = torch.argmax(next_token_logits, dim=-1)

            neuron_outputs = [next_tokens] + past_key_values_sa + past_key_values_ca

        if self.output_hidden_states:
            neuron_outputs += decoder_hidden_states

        if self.output_attentions:
            neuron_outputs += decoder_attentions
            neuron_outputs += cross_attentions

        return neuron_outputs


class SentenceTransformersTransformerNeuronWrapper(torch.nn.Module):
    """Wrapper returning the token and sentence embeddings of a sentence-transformers model as a tuple."""

    def __init__(self, model, input_names: List[str], device: Optional[str] = None):
        super().__init__()
        self.model = model
        self.input_names = input_names
        self.device = device

    def forward(self, input_ids, attention_mask):
        """Return `(token_embeddings, sentence_embedding)` taken from the wrapped model's output."""
        out_tuple = self.model({"input_ids": input_ids, "attention_mask": attention_mask})

        return out_tuple["token_embeddings"], out_tuple["sentence_embedding"]


class CLIPVisionWithProjectionNeuronWrapper(torch.nn.Module):
    """Wrapper running a CLIP vision model followed by its visual projection."""

    def __init__(
        self,
        model,
        input_names: List[str],
        output_hidden_states: bool = True,
        device: Optional[str] = None,
    ):
        super().__init__()
        self.model = model
        self.input_names = input_names
        self.output_hidden_states = output_hidden_states
        self.device = device

    def forward(self, pixel_values):
        """Return `(image_embeds, last_hidden_state)`, plus the hidden states when requested."""
        vision_outputs = self.model.vision_model(
            pixel_values=pixel_values, output_hidden_states=self.output_hidden_states
        )
        pooled_output = vision_outputs[1]
        image_embeds = self.model.visual_projection(pooled_output)

        outputs = (image_embeds, vision_outputs.last_hidden_state)

        if self.output_hidden_states:
            outputs += (vision_outputs.hidden_states,)
        return outputs


class SentenceTransformersCLIPNeuronWrapper(torch.nn.Module):
    """Wrapper computing the text and image embeddings of a sentence-transformers CLIP model."""

    def __init__(self, model, input_names: List[str], device: Optional[str] = None):
        super().__init__()
        self.model = model
        self.input_names = input_names
        self.device = device

    def forward(self, input_ids, pixel_values, attention_mask):
        """Return `(text_embeds, image_embeds)`.

        Both embeddings are projected by the first module and, when the wrapped model holds
        further modules, passed through those as well.
        """
        clip_model = self.model[0].model

        vision_outputs = clip_model.vision_model(pixel_values=pixel_values)
        image_embeds = clip_model.visual_projection(vision_outputs[1])

        text_outputs = clip_model.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        text_embeds = clip_model.text_projection(text_outputs[1])

        if len(self.model) > 1:
            image_embeds = self.model[1:](image_embeds)
            text_embeds = self.model[1:](text_embeds)

        return (text_embeds, image_embeds)


class WhisperEncoderWrapper(torch.nn.Module):
    """Wrapper to trace the forward of Whisper encoder."""

    def __init__(
        self,
        model: "PreTrainedModel",
        batch_size: int,
        device: Optional[str] = None,
        **kwargs,
    ):
        super().__init__()
        self.model = model
        self.config = model.config
        self.batch_size = batch_size
        self.device = device

    def forward(
        self,
        input_features,
        decoder_input_ids,
        **kwargs,
    ):
        """Run the encoder and a first decoder pass.

        Returns:
            A tuple `(lm_logits, encoder_last_hidden_state)`.
        """
        # encoder
        encoder_outputs = self.model.model.encoder(
            input_features=input_features,
            return_dict=True,
        )
        # 1st decoder + proj_out
        decoder_outputs = self.model.model.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_outputs[0],
            use_cache=False,
            return_dict=True,
        )
        lm_logits = self.model.proj_out(decoder_outputs[0])

        return (lm_logits, encoder_outputs.last_hidden_state)


class WhisperDecoderWrapper(torch.nn.Module):
    """Wrapper to trace the forward of Whisper decoder."""

    def __init__(
        self,
        model: "PreTrainedModel",
        batch_size: int,
        sequence_length: int,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
        device: Optional[str] = None,
        **kwargs,
    ):
        super().__init__()
        self.model = model
        self.config = model.config
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        # NOTE(review): both flags are stored but `forward` always passes False to the decoder.
        self.output_hidden_states = output_hidden_states
        self.output_attentions = output_attentions
        # Fall back to the XLA device when no device is given.
        self.device = device if device else xm.xla_device()

    def forward(
        self,
        input_ids,
        encoder_hidden_states,
        **kwargs,
    ):
        """Run the decoder without cache and return the projected logits."""
        cache_position = torch.arange(input_ids.shape[1]).to(self.device)
        outputs = self.model.model.decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states,
            use_cache=False,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
            cache_position=cache_position,
        )
        lm_logits = self.model.proj_out(outputs[0])
        return lm_logits


class NoCacheModelWrapper(torch.nn.Module):
    """Wrapper calling a model with `use_cache=False` from purely positional inputs."""

    def __init__(self, model: "PreTrainedModel", input_names: List[str]):
        super().__init__()
        self.model = model
        self.input_names = input_names

    def forward(self, *inputs):
        """Match `inputs` to `self.input_names` by position and call the model without cache.

        No length check is done here: `zip` silently drops surplus inputs or names.
        """
        ordered_inputs = dict(zip(self.input_names, inputs))
        outputs = self.model(use_cache=False, **ordered_inputs)
        return outputs