in optimum/habana/transformers/generation/utils.py [0:0]
def _pad_past_key_values(self, model_kwargs):
    # Early return if there are no past key values to pad
    past_key_values = model_kwargs.get("past_key_values")
    if not past_key_values:
        return

    # Determine whether the model uses MQA (multi-query attention)
    is_mqa_model = model_kwargs.get("mqa_model", False)
    lazy_mode = model_kwargs.get("lazy_mode", False)
    pad_amount = model_kwargs.get("kv_cache_pad_len", 0)
    kv_cache_len = model_kwargs.get("kv_cache_len", 0)
    # Unpadded cache length: only entries still at this size need padding
    kv_cache_len_pad_amount = kv_cache_len - pad_amount

    # For MQA models, each layer's cache is a single tensor with k and v stacked
    if is_mqa_model:
        for i, layer in enumerate(past_key_values):  # Iterate over layers
            if torch.is_tensor(layer) and layer.shape[-2] == kv_cache_len_pad_amount:
                # Shape: (batch_size, kv_cache_len, n_heads * head_dim * 2), k and v stacked
                past_key_values[i] = torch.nn.functional.pad(layer, (0, 0, 0, pad_amount))
                # Mark step if lazy mode is enabled
                if lazy_mode:
                    self.htcore_generation.mark_step()
    # For non-MQA models, past_key_values is a list of [k, v] lists, one per layer
    else:
        for i, layer in enumerate(past_key_values):  # Iterate over layers
            for j, k_or_v in enumerate(layer):  # Iterate over k and v
                if torch.is_tensor(k_or_v) and k_or_v.shape[-2] == kv_cache_len_pad_amount:
                    # Shape: (batch_size, n_heads, kv_cache_len, head_dim)
                    past_key_values[i][j] = torch.nn.functional.pad(k_or_v, (0, 0, 0, pad_amount))
                    # Mark step if lazy mode is enabled
                    if lazy_mode:
                        self.htcore_generation.mark_step()
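
For reference, a minimal standalone sketch (plain PyTorch, not part of the file above; the tensor sizes are made up) of how the (0, 0, 0, pad_amount) argument behaves: torch.nn.functional.pad consumes the pad tuple from the last dimension backwards, so this tuple leaves the last dimension untouched and appends pad_amount zero entries to the second-to-last dimension, which is the kv_cache length axis in both branches.

# Sketch only: illustrates the pad semantics used by _pad_past_key_values
import torch
import torch.nn.functional as F

pad_amount = 3
# Hypothetical non-MQA layout: (batch_size, n_heads, kv_cache_len, head_dim)
k = torch.zeros(2, 4, 5, 8)
# (0, 0) -> last dim (head_dim) unchanged; (0, pad_amount) -> append to dim -2 (kv_cache_len)
k_padded = F.pad(k, (0, 0, 0, pad_amount))
print(k_padded.shape)  # torch.Size([2, 4, 8, 8]): kv_cache_len grew from 5 to 8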