in optimum/exporters/neuron/model_wrappers.py [0:0]
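# Note: this excerpt assumes module-level imports of `torch` and `T5LayerCrossAttention`
# (from transformers.models.t5.modeling_t5), which are not shown here.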
def forward(self, input_ids, attention_mask):
    # Infer shapes of dummy inputs used for tracing
    batch_size = input_ids.shape[0]
    sequence_length = input_ids.shape[1]
    if self.sequence_length is not None:
        assert self.sequence_length == sequence_length, (
            f"Different sequence length for the parallel partition ({self.sequence_length}) and for dummy inputs ({sequence_length}). Make sure that they have the same value."
        )
    if self.batch_size is not None:
        assert self.batch_size == batch_size, (
            f"Different batch size for the parallel partition ({self.batch_size}) and for dummy inputs ({batch_size}). Make sure that they have the same value."
        )
    encoder_output = self.model.encoder(
        input_ids=input_ids, attention_mask=attention_mask, output_attentions=False, output_hidden_states=False
    )
    last_hidden_state = encoder_output["last_hidden_state"]
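    # Repeat the encoder output for each beam so every beam hypothesis attends to its own copy of
    # the encoder states, giving a (batch_size * num_beams, seq_len, d_model) tensor.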
    encoder_hidden_states = torch.concat(
        [tensor.unsqueeze(0).repeat(self.num_beams, 1, 1) for tensor in last_hidden_state]
    )
    decoder_blocks = self.model.decoder.block
    present_key_value_states_sa = []
    present_key_value_states_ca = []
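    # Prefill the per-layer key/value caches: cross-attention K/V are computed from the encoder
    # output, while self-attention K/V are only placeholders with the final (padded) cache shape.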
    for i, block in enumerate(decoder_blocks):
        # Cross attention has to be initialized with the encoder hidden state
        cross_attention: T5LayerCrossAttention = block.layer[1]
        attention = cross_attention.EncDecAttention
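        # Reshape projected states to (batch_size * num_beams, n_heads_per_partition, seq_len, d_kv).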
        def shape(states):
            """projection"""
            return states.view(
                self.num_beams * batch_size,
                -1,
                self.num_attention_heads_per_partition,
                attention.key_value_proj_dim,
            ).transpose(1, 2)
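        # Cross-attention K/V depend only on the encoder output, so they can be computed once here
        # and reused at every decoding step.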
        key_states = shape(attention.k(encoder_hidden_states))
        value_states = shape(attention.v(encoder_hidden_states))
        if self.tensor_parallel_size <= 1:
            # cross_attn_kv_state
            present_key_value_states_ca.append(key_states)
            present_key_value_states_ca.append(value_states)
            # Self-attention kv states are initialized to zeros and padded to the full sequence
            # length so that the kv cache tensor keeps a constant shape.
            # [key states]
            present_key_value_states_sa.append(
                torch.zeros(
                    (self.num_beams * batch_size, self.config.num_heads, sequence_length - 1, self.config.d_kv),
                    dtype=torch.float32,
                    device=self.device,
                )
            )
            # [value states]
            present_key_value_states_sa.append(
                torch.zeros(
                    (self.num_beams * batch_size, self.config.num_heads, sequence_length - 1, self.config.d_kv),
                    dtype=torch.float32,
                    device=self.device,
                )
            )
        else:
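            # Tensor-parallel case: multiply the preallocated cache buffers by zero and add the fresh
            # states, so the traced graph keeps a dependency on `past_key_values_ca`/`past_key_values_sa`
            # (presumably for buffer aliasing on Neuron) while still producing the correct values.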
            present_key_value_states_ca.append((self.past_key_values_ca[i * 2] * 0) + key_states)
            present_key_value_states_ca.append((self.past_key_values_ca[i * 2 + 1] * 0) + value_states)
            present_key_value_states_sa.append(
                self.past_key_values_sa[i * 2]
                * torch.zeros(
                    (
                        self.num_beams * self.batch_size,
                        self.num_attention_heads_per_partition,
                        self.sequence_length - 1,
                        self.config.d_kv,
                    ),
                    dtype=torch.float32,
                    device=self.device,
                )
            )
            present_key_value_states_sa.append(
                self.past_key_values_sa[i * 2 + 1]
                * torch.zeros(
                    (
                        self.num_beams * self.batch_size,
                        self.num_attention_heads_per_partition,
                        self.sequence_length - 1,
                        self.config.d_kv,
                    ),
                    dtype=torch.float32,
                    device=self.device,
                )
            )

    return present_key_value_states_sa + present_key_value_states_ca