# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model wrappers for Neuron export."""
from typing import TYPE_CHECKING, List, Optional
import torch
from transformers.models.t5.modeling_t5 import T5LayerCrossAttention
from ...neuron.utils import is_neuronx_available, is_neuronx_distributed_available
if is_neuronx_available():
import torch_xla.core.xla_model as xm
if is_neuronx_distributed_available():
import neuronx_distributed
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
class UnetNeuronWrapper(torch.nn.Module):
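    """Wrapper to trace the forward pass of a diffusers UNet.

    Tracing feeds the model positional tensors, so this wrapper maps them back to the keyword
    arguments expected by the UNet, rebuilding `added_cond_kwargs` and the ControlNet
    `down_block_additional_residuals` / `mid_block_additional_residual` along the way.
    """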
def __init__(self, model, input_names: List[str], device: Optional[str] = None):
super().__init__()
self.model = model
self.input_names = input_names
self.device = device
def forward(self, *inputs):
if len(inputs) != len(self.input_names):
raise ValueError(
f"The model needs {len(self.input_names)} inputs: {self.input_names}."
f" But only {len(inputs)} inputs are passed."
)
ordered_inputs = dict(zip(self.input_names, inputs))
added_cond_kwargs = {
"text_embeds": ordered_inputs.pop("text_embeds", None),
"time_ids": ordered_inputs.pop("time_ids", None),
"image_embeds": ordered_inputs.pop("image_embeds", None)
or ordered_inputs.pop("image_enc_hidden_states", None),
}
sample = ordered_inputs.pop("sample", None)
timestep = ordered_inputs.pop("timestep").float().expand((sample.shape[0],))
encoder_hidden_states = ordered_inputs.pop("encoder_hidden_states", None)
        # Re-build the down_block_additional_residuals tuple from the flattened inputs
down_block_additional_residuals = ()
down_block_additional_residuals_names = [
name for name in ordered_inputs.keys() if "down_block_additional_residuals" in name
]
for name in down_block_additional_residuals_names:
value = ordered_inputs.pop(name)
down_block_additional_residuals += (value,)
mid_block_additional_residual = ordered_inputs.pop("mid_block_additional_residual", None)
out_tuple = self.model(
sample=sample,
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
down_block_additional_residuals=(
down_block_additional_residuals if down_block_additional_residuals else None
),
mid_block_additional_residual=mid_block_additional_residual,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)
return out_tuple
class PixartTransformerNeuronWrapper(torch.nn.Module):
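    """Wrapper to trace the forward pass of a PixArt transformer.

    Positional inputs are mapped back to keyword arguments; the extra resolution / aspect ratio
    conditions are not used and are passed as `None`.
    """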
    def __init__(self, model, input_names: List[str], device: Optional[str] = None):
super().__init__()
self.model = model
self.dtype = model.dtype
self.input_names = input_names
self.device = device
def forward(self, *inputs):
if len(inputs) != len(self.input_names):
raise ValueError(
f"The model needs {len(self.input_names)} inputs: {self.input_names}."
f" But only {len(input)} inputs are passed."
)
ordered_inputs = dict(zip(self.input_names, inputs))
sample = ordered_inputs.pop("sample", None)
encoder_hidden_states = ordered_inputs.pop("encoder_hidden_states", None)
timestep = ordered_inputs.pop("timestep", None)
encoder_attention_mask = ordered_inputs.pop("encoder_attention_mask", None)
# Additional conditions
out_tuple = self.model(
hidden_states=sample,
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
added_cond_kwargs={"resolution": None, "aspect_ratio": None},
return_dict=False,
)
return out_tuple
class ControlNetNeuronWrapper(torch.nn.Module):
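    """Wrapper to trace the forward pass of a ControlNet model.

    Positional inputs are mapped back to keyword arguments, including the additional conditions
    (`text_embeds`, `time_ids`) used with Stable Diffusion XL. Guess mode is not supported yet.
    """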
    def __init__(self, model, input_names: List[str], device: Optional[str] = None):
super().__init__()
self.model = model
self.input_names = input_names
self.device = device
def forward(self, *inputs):
if len(inputs) != len(self.input_names):
raise ValueError(
f"The model needs {len(self.input_names)} inputs: {self.input_names}."
f" But only {len(input)} inputs are passed."
)
ordered_inputs = dict(zip(self.input_names, inputs))
sample = ordered_inputs.pop("sample", None)
timestep = ordered_inputs.pop("timestep", None)
encoder_hidden_states = ordered_inputs.pop("encoder_hidden_states", None)
controlnet_cond = ordered_inputs.pop("controlnet_cond", None)
conditioning_scale = ordered_inputs.pop("conditioning_scale", None)
# Additional conditions for the Stable Diffusion XL UNet.
added_cond_kwargs = {
"text_embeds": ordered_inputs.pop("text_embeds", None),
"time_ids": ordered_inputs.pop("time_ids", None),
}
out_tuple = self.model(
sample=sample,
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=controlnet_cond,
conditioning_scale=conditioning_scale,
added_cond_kwargs=added_cond_kwargs,
guess_mode=False, # TODO: support guess mode of ControlNet
return_dict=False,
**ordered_inputs,
)
return out_tuple
# Adapted from https://github.com/aws-neuron/aws-neuron-samples/blob/master/torch-neuronx/inference/hf_pretrained_pixart_alpha_inference_on_inf2.ipynb
# For text encoding
class T5EncoderWrapper(torch.nn.Module):
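    """Wrapper to trace the T5 encoder used as a standalone text encoder (e.g. for PixArt).

    The feed-forward activation is replaced with a tanh-approximated `torch.nn.GELU`, and the
    relative attention bias is precomputed for the fixed sequence length and reused as a constant.
    """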
def __init__(
self, model: "PreTrainedModel", sequence_length: int, batch_size: Optional[int] = None, device: str = "cpu"
):
super().__init__()
self.model = model
self.config = model.config
self.sequence_length = sequence_length
self.batch_size = batch_size
self.device = device
for block in self.model.encoder.block:
block.layer[1].DenseReluDense.act = torch.nn.GELU(approximate="tanh")
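        # Precompute the relative attention bias once for the fixed sequence length and replace
        # `compute_bias` with a closure returning it, so the bias is a constant in the traced graph.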
precomputed_bias = (
self.model.encoder.block[0].layer[0].SelfAttention.compute_bias(self.sequence_length, self.sequence_length)
)
self.model.encoder.block[0].layer[0].SelfAttention.compute_bias = lambda *args, **kwargs: precomputed_bias
def forward(self, input_ids, attention_mask):
return self.model(input_ids, attention_mask=attention_mask)
# Adapted from https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/torch-neuronx/t5-inference-tutorial.html
# For text encoding + KV cache initialization
class T5EncoderForSeq2SeqLMWrapper(torch.nn.Module):
"""Wrapper to trace the encoder and the kv cache initialization in the decoder."""
def __init__(
self,
model: "PreTrainedModel",
sequence_length: Optional[int] = None,
batch_size: Optional[int] = None,
num_beams: int = 1,
device: str = "xla",
tensor_parallel_size: int = 1,
):
super().__init__()
self.model = model
self.config = model.config
self.num_beams = num_beams
self.sequence_length = sequence_length
self.batch_size = batch_size
self.device = device
self.tensor_parallel_size = tensor_parallel_size
self.num_attention_heads_per_partition = self.config.num_heads # when tensor_parallel_size=1
if self.tensor_parallel_size > 1:
self.num_attention_heads_per_partition = (
self.num_attention_heads_per_partition
// neuronx_distributed.parallel_layers.parallel_state.get_tensor_model_parallel_size()
)
self.past_key_values_sa = torch.nn.ParameterList(
[
torch.nn.Parameter(
torch.ones(
(
self.num_beams * batch_size,
self.num_attention_heads_per_partition,
self.sequence_length - 1,
self.config.d_kv,
),
dtype=torch.float32,
),
requires_grad=False,
)
for _ in range(self.config.num_decoder_layers * 2)
]
)
self.past_key_values_ca = torch.nn.ParameterList(
[
torch.nn.Parameter(
torch.ones(
(
self.num_beams * batch_size,
self.num_attention_heads_per_partition,
self.sequence_length,
self.config.d_kv,
),
dtype=torch.float32,
),
requires_grad=False,
)
for _ in range(self.config.num_decoder_layers * 2)
]
)
def forward(self, input_ids, attention_mask):
# Infer shapes of dummy inputs used for tracing
batch_size = input_ids.shape[0]
sequence_length = input_ids.shape[1]
        if self.sequence_length is not None:
            assert self.sequence_length == sequence_length, (
                f"Different sequence length for the parallel partition ({self.sequence_length}) and for the dummy inputs ({sequence_length}). Make sure that they have the same value."
            )
        if self.batch_size is not None:
            assert self.batch_size == batch_size, (
                f"Different batch size for the parallel partition ({self.batch_size}) and for the dummy inputs ({batch_size}). Make sure that they have the same value."
            )
encoder_output = self.model.encoder(
input_ids=input_ids, attention_mask=attention_mask, output_attentions=False, output_hidden_states=False
)
last_hidden_state = encoder_output["last_hidden_state"]
encoder_hidden_states = torch.concat(
[tensor.unsqueeze(0).repeat(self.num_beams, 1, 1) for tensor in last_hidden_state]
)
decoder_blocks = self.model.decoder.block
present_key_value_states_sa = []
present_key_value_states_ca = []
for i, block in enumerate(decoder_blocks):
# Cross attention has to be initialized with the encoder hidden state
cross_attention: T5LayerCrossAttention = block.layer[1]
attention = cross_attention.EncDecAttention
def shape(states):
"""projection"""
return states.view(
self.num_beams * batch_size,
-1,
self.num_attention_heads_per_partition,
attention.key_value_proj_dim,
).transpose(1, 2)
key_states = shape(attention.k(encoder_hidden_states))
value_states = shape(attention.v(encoder_hidden_states))
            if self.tensor_parallel_size == 1:
# cross_attn_kv_state
present_key_value_states_ca.append(key_states)
present_key_value_states_ca.append(value_states)
                # Self-attention KV states are initialized to zeros and padded so that the KV cache
                # keeps a constant shape.
                # [key states]
present_key_value_states_sa.append(
torch.zeros(
(self.num_beams * batch_size, self.config.num_heads, sequence_length - 1, self.config.d_kv),
dtype=torch.float32,
device=self.device,
)
)
# [value states]
present_key_value_states_sa.append(
torch.zeros(
(self.num_beams * batch_size, self.config.num_heads, sequence_length - 1, self.config.d_kv),
dtype=torch.float32,
device=self.device,
)
)
else:
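                # NOTE(assumption): multiplying the preallocated cache parameters by zero before adding
                # the fresh key/value states presumably keeps the outputs tied to those parameters for
                # Neuron input/output aliasing; numerically the result is just the fresh states.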
present_key_value_states_ca.append((self.past_key_values_ca[i * 2] * 0) + key_states)
present_key_value_states_ca.append((self.past_key_values_ca[i * 2 + 1] * 0) + value_states)
present_key_value_states_sa.append(
self.past_key_values_sa[i * 2]
* torch.zeros(
(
self.num_beams * self.batch_size,
self.num_attention_heads_per_partition,
self.sequence_length - 1,
self.config.d_kv,
),
dtype=torch.float32,
device=self.device,
)
)
present_key_value_states_sa.append(
self.past_key_values_sa[i * 2 + 1]
* torch.zeros(
(
self.num_beams * self.batch_size,
self.num_attention_heads_per_partition,
self.sequence_length - 1,
self.config.d_kv,
),
dtype=torch.float32,
device=self.device,
)
)
return present_key_value_states_sa + present_key_value_states_ca
# Adapted from https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/torch-neuronx/t5-inference-tutorial.html
class T5DecoderWrapper(torch.nn.Module):
"""Wrapper to trace the decoder with past keys values with a language head."""
def __init__(
self,
model: "PreTrainedModel",
batch_size: int,
sequence_length: int,
num_beams: int = 1,
output_hidden_states: bool = False,
output_attentions: bool = False,
device: str = "xla",
tensor_parallel_size: int = 1,
):
super().__init__()
self.model = model
self.config = model.config
self.batch_size = batch_size
self.sequence_length = sequence_length
self.num_beams = num_beams
self.output_hidden_states = output_hidden_states
self.output_attentions = output_attentions
self.device = device
self.tensor_parallel_size = tensor_parallel_size
self.num_attention_heads_per_partition = self.config.num_heads
if tensor_parallel_size > 1:
self.num_attention_heads_per_partition = (
self.num_attention_heads_per_partition
// neuronx_distributed.parallel_layers.parallel_state.get_tensor_model_parallel_size()
)
# Initialize KV cache (num_beams, n_heads, seq_length, dim_per_head)
if device == "cpu":
self.past_key_values_sa = [
torch.ones(
(num_beams, self.config.num_heads, self.sequence_length - 1, self.config.d_kv), dtype=torch.float32
)
for _ in range(self.config.num_decoder_layers * 2)
]
self.past_key_values_ca = [
torch.ones(
(num_beams, self.config.num_heads, self.sequence_length, self.config.d_kv), dtype=torch.float32
)
for _ in range(self.config.num_decoder_layers * 2)
]
elif device == "xla":
self.past_key_values_sa = torch.nn.ParameterList(
[
torch.nn.Parameter(
torch.ones(
(
self.batch_size * self.num_beams,
self.num_attention_heads_per_partition,
sequence_length - 1,
self.config.d_kv,
),
dtype=torch.float32,
),
requires_grad=False,
)
for _ in range(self.config.num_decoder_layers * 2)
]
)
self.past_key_values_ca = torch.nn.ParameterList(
[
torch.nn.Parameter(
torch.ones(
(
self.batch_size * self.num_beams,
self.num_attention_heads_per_partition,
sequence_length,
self.config.d_kv,
),
dtype=torch.float32,
),
requires_grad=False,
)
for _ in range(self.config.num_decoder_layers * 2)
]
)
def update_past(self, past_key_values):
new_past_sa = []
new_past_ca = []
for past_layer in past_key_values:
new_past_layer = list(past_layer)
for i in range(len(new_past_layer[:2])):
new_past_layer[i] = past_layer[i][:, :, 1:]
new_past_sa += [
new_past_layer[:2],
]
new_past_ca += [
new_past_layer[2:],
]
return new_past_sa, new_past_ca
def reorder_cache(self, past_key_values, beam_idx):
for i in range(len(past_key_values)):
gather_index = beam_idx.view([beam_idx.shape[0], 1, 1, 1]).expand_as(past_key_values[i])
past_key_values[i] = torch.gather(past_key_values[i], dim=0, index=gather_index)
return past_key_values
def forward(
self,
input_ids,
decoder_attention_mask,
encoder_hidden_states,
encoder_attention_mask,
beam_idx,
beam_scores,
**kwargs,
):
if self.num_beams > 1:
# We reorder the cache based on the beams selected in each iteration. Required step for beam search.
past_key_values_sa = self.reorder_cache(self.past_key_values_sa, beam_idx)
past_key_values_ca = self.reorder_cache(self.past_key_values_ca, beam_idx)
else:
# We do not need to reorder for greedy sampling
past_key_values_sa = self.past_key_values_sa
past_key_values_ca = self.past_key_values_ca
        # The cache is stored in a flattened form. We group it per layer before passing it to the decoder.
        # Each layer has 4 tensors (self-attention key/value followed by cross-attention key/value).
past_key_values = [
[*past_key_values_sa[i * 2 : i * 2 + 2], *past_key_values_ca[i * 2 : i * 2 + 2]]
            for i in range(len(past_key_values_ca) // 2)
]
decoder_output = self.model.decoder(
input_ids=input_ids,
attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
output_attentions=self.output_attentions,
output_hidden_states=self.output_hidden_states,
)
last_hidden_state = decoder_output["last_hidden_state"]
past_key_values = decoder_output["past_key_values"]
if self.output_hidden_states:
decoder_hidden_states = list(
decoder_output["hidden_states"]
) # flatten `hidden_states` which is a tuple of tensors
if self.output_attentions:
decoder_attentions = list(
decoder_output["attentions"]
) # flatten `decoder_attentions` which is a tuple of tensors
cross_attentions = list(
decoder_output["cross_attentions"]
) # flatten `cross_attentions` which is a tuple of tensors
if self.config.tie_word_embeddings:
# Rescale output before projecting on vocab
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
last_hidden_state = last_hidden_state * (self.model.config.d_model**-0.5)
lm_logits = self.model.lm_head(last_hidden_state)
past_key_values_sa, past_key_values_ca = self.update_past(past_key_values)
        # We flatten the cache into a single list. This is required for input/output aliasing to work.
past_key_values_sa = [vec for kv_per_layer in past_key_values_sa for vec in kv_per_layer]
past_key_values_ca = [vec for kv_per_layer in past_key_values_ca for vec in kv_per_layer]
if self.device == "cpu":
self.past_key_values_sa = past_key_values_sa
self.past_key_values_ca = past_key_values_ca
# We calculate topk inside the wrapper
next_token_logits = lm_logits[:, -1, :]
if self.num_beams > 1:
            # This part of beam search runs outside the decoder in the Hugging Face T5 implementation.
            # To keep as much computation as possible on the Neuron device, we move it into the wrapper.
logit_max, _ = torch.max(next_token_logits, dim=-1, keepdim=True)
logsumexp = torch.log(torch.exp(next_token_logits - logit_max).sum(dim=-1, keepdim=True))
next_token_scores = next_token_logits - logit_max - logsumexp
next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
# reshape for beam search
vocab_size = next_token_scores.shape[-1]
next_token_scores = next_token_scores.view(self.batch_size, self.num_beams * vocab_size)
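            # NOTE(assumption): the multiplication by 1 below is a numerical no-op, presumably kept to
            # force a fresh tensor in the traced graph.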
next_token_scores = next_token_scores * 1
# Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
next_token_scores, next_tokens = torch.topk(
next_token_scores, 2 * self.num_beams, dim=1, largest=True, sorted=True
)
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
next_tokens = next_tokens % vocab_size
neuron_outputs = [next_token_scores, next_tokens, next_indices] + past_key_values_sa + past_key_values_ca
else:
# Greedy
next_tokens = torch.argmax(next_token_logits, dim=-1)
neuron_outputs = [next_tokens] + past_key_values_sa + past_key_values_ca
if self.output_hidden_states:
neuron_outputs += decoder_hidden_states
if self.output_attentions:
neuron_outputs += decoder_attentions
neuron_outputs += cross_attentions
return neuron_outputs
class SentenceTransformersTransformerNeuronWrapper(torch.nn.Module):
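    """Wrapper to trace a sentence-transformers Transformer model.

    Returns the token embeddings and the pooled sentence embedding.
    """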
    def __init__(self, model, input_names: List[str], device: Optional[str] = None):
super().__init__()
self.model = model
self.input_names = input_names
self.device = device
def forward(self, input_ids, attention_mask):
out_tuple = self.model({"input_ids": input_ids, "attention_mask": attention_mask})
return out_tuple["token_embeddings"], out_tuple["sentence_embedding"]
class CLIPVisionWithProjectionNeuronWrapper(torch.nn.Module):
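    """Wrapper to trace `CLIPVisionModelWithProjection`.

    Returns the projected image embeddings, the last hidden state and, optionally, the hidden
    states of all layers.
    """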
def __init__(
self,
model,
input_names: List[str],
output_hidden_states: bool = True,
        device: Optional[str] = None,
):
super().__init__()
self.model = model
self.input_names = input_names
self.output_hidden_states = output_hidden_states
self.device = device
def forward(self, pixel_values):
vision_outputs = self.model.vision_model(
pixel_values=pixel_values, output_hidden_states=self.output_hidden_states
)
pooled_output = vision_outputs[1]
image_embeds = self.model.visual_projection(pooled_output)
outputs = (image_embeds, vision_outputs.last_hidden_state)
if self.output_hidden_states:
outputs += (vision_outputs.hidden_states,)
return outputs
class SentenceTransformersCLIPNeuronWrapper(torch.nn.Module):
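    """Wrapper to trace a sentence-transformers CLIP model.

    Computes the text and image embeddings with the underlying CLIP model, then applies any extra
    sentence-transformers modules (e.g. a projection head) on top of them.
    """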
    def __init__(self, model, input_names: List[str], device: Optional[str] = None):
super().__init__()
self.model = model
self.input_names = input_names
self.device = device
def forward(self, input_ids, pixel_values, attention_mask):
vision_outputs = self.model[0].model.vision_model(pixel_values=pixel_values)
image_embeds = self.model[0].model.visual_projection(vision_outputs[1])
text_outputs = self.model[0].model.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
)
text_embeds = self.model[0].model.text_projection(text_outputs[1])
if len(self.model) > 1:
image_embeds = self.model[1:](image_embeds)
text_embeds = self.model[1:](text_embeds)
return (text_embeds, image_embeds)
class WhisperEncoderWrapper(torch.nn.Module):
"""Wrapper to trace the forward of Whisper encoder."""
def __init__(
self,
model: "PreTrainedModel",
batch_size: int,
        device: Optional[str] = None,
**kwargs,
):
super().__init__()
self.model = model
self.config = model.config
self.batch_size = batch_size
self.device = device
def forward(
self,
input_features,
decoder_input_ids,
**kwargs,
):
# encoder
encoder_outputs = self.model.model.encoder(
input_features=input_features,
return_dict=True,
)
# 1st decoder + proj_out
decoder_outputs = self.model.model.decoder(
input_ids=decoder_input_ids,
encoder_hidden_states=encoder_outputs[0],
use_cache=False,
return_dict=True,
)
lm_logits = self.model.proj_out(decoder_outputs[0])
return (lm_logits, encoder_outputs.last_hidden_state)
class WhisperDecoderWrapper(torch.nn.Module):
"""Wrapper to trace the forward of Whisper decoder."""
def __init__(
self,
model: "PreTrainedModel",
batch_size: int,
sequence_length: int,
output_hidden_states: bool = False,
output_attentions: bool = False,
        device: Optional[str] = None,
**kwargs,
):
super().__init__()
self.model = model
self.config = model.config
self.batch_size = batch_size
self.sequence_length = sequence_length
self.output_hidden_states = output_hidden_states
self.output_attentions = output_attentions
self.device = device if device else xm.xla_device()
def forward(
self,
input_ids,
encoder_hidden_states,
**kwargs,
):
cache_position = torch.arange(input_ids.shape[1]).to(self.device)
outputs = self.model.model.decoder(
input_ids=input_ids,
encoder_hidden_states=encoder_hidden_states,
use_cache=False,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
cache_position=cache_position,
)
lm_logits = self.model.proj_out(outputs[0])
return lm_logits
class NoCacheModelWrapper(torch.nn.Module):
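    """Wrapper to trace a model without KV cache, mapping positional inputs back to keyword
    arguments and forcing `use_cache=False`."""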
def __init__(self, model: "PreTrainedModel", input_names: List[str]):
super().__init__()
self.model = model
self.input_names = input_names
    def forward(self, *inputs):
        ordered_inputs = dict(zip(self.input_names, inputs))
outputs = self.model(use_cache=False, **ordered_inputs)
return outputs
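

# A minimal usage sketch, kept as a comment so that importing this module has no side effects:
# every wrapper above is a plain `torch.nn.Module`, so it can be traced like any other module once
# it is fed dummy inputs with the shapes to compile for. The model name and the
# `torch_neuronx.trace` call below are illustrative assumptions about the surrounding tooling, not
# the canonical export path implemented in `optimum.exporters.neuron`.
#
#     import torch_neuronx
#     from transformers import AutoModel, AutoTokenizer
#
#     model = AutoModel.from_pretrained("bert-base-uncased")
#     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#     dummy = tokenizer("hello world", return_tensors="pt", padding="max_length", max_length=128)
#     wrapper = NoCacheModelWrapper(model, input_names=["input_ids", "attention_mask"])
#     traced = torch_neuronx.trace(wrapper, (dummy["input_ids"], dummy["attention_mask"]))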