def _pad_past_key_values(self, model_kwargs)

in optimum/habana/transformers/generation/utils.py
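
Right-pads each layer's KV cache along the sequence dimension (dim -2) by kv_cache_pad_len, bringing it up to the bucketed kv_cache_len so that cache shapes stay static across generation steps; in lazy mode a mark_step is issued after each padded layer.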


    def _pad_past_key_values(self, model_kwargs):
        # Early return if no past key values to pad
        past_key_values = model_kwargs.get("past_key_values")
        if not past_key_values:
            return

        # Read the model type and padding parameters from model_kwargs
        is_mqa_model = model_kwargs.get("mqa_model", False)
        lazy_mode = model_kwargs.get("lazy_mode", False)
        pad_amount = model_kwargs.get("kv_cache_pad_len", 0)
        kv_cache_len = model_kwargs.get("kv_cache_len", 0)
        # Sequence length the cache currently has if it still needs padding
        kv_cache_len_pad_amount = kv_cache_len - pad_amount

        # For MQA models, each layer's cache is a single tensor (k and v stacked)
        if is_mqa_model:
            for i, layer in enumerate(past_key_values):  # Iterate over layers
                if torch.is_tensor(layer) and layer.shape[-2] == kv_cache_len_pad_amount:
                    # Shape: (batch_size, seq_len, n_heads * head_dim * 2), with k and v stacked;
                    # (0, 0, 0, pad_amount) right-pads dim -2, bringing seq_len up to kv_cache_len
                    past_key_values[i] = torch.nn.functional.pad(layer, (0, 0, 0, pad_amount))
                    # Mark step if lazy mode is enabled
                    if lazy_mode:
                        self.htcore_generation.mark_step()
        # For non-MQA models, past_key_values is a list of per-layer [key, value] lists
        else:
            for i, layer in enumerate(past_key_values):  # Iterate over layers
                for j, k_or_v in enumerate(layer):  # Iterate over k and v
                    if torch.is_tensor(k_or_v) and k_or_v.shape[-2] == kv_cache_len_pad_amount:
                        # Shape: (batch_size, n_heads, seq_len, head_dim);
                        # right-pad dim -2 so seq_len reaches kv_cache_len
                        past_key_values[i][j] = torch.nn.functional.pad(k_or_v, (0, 0, 0, pad_amount))
                        # Mark step if lazy mode is enabled
                        if lazy_mode:
                            self.htcore_generation.mark_step()
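
As a minimal, self-contained sketch (separate from utils.py; all sizes below are illustrative assumptions, not values from the source), the following shows what the pad specification (0, 0, 0, pad_amount) does to both cache layouts:

    import torch
    import torch.nn.functional as F

    # Illustrative sizes, not taken from the source
    batch_size, n_heads, head_dim = 2, 4, 8
    kv_cache_len = 128                        # bucketed target length
    pad_amount = 16                           # kv_cache_pad_len
    current_len = kv_cache_len - pad_amount   # length before padding

    # Non-MQA layout: one (k, v) tensor pair per layer,
    # each shaped (batch_size, n_heads, seq_len, head_dim)
    k = torch.zeros(batch_size, n_heads, current_len, head_dim)
    # (0, 0, 0, pad_amount) leaves the last dim alone and right-pads dim -2
    k_padded = F.pad(k, (0, 0, 0, pad_amount))
    assert k_padded.shape == (batch_size, n_heads, kv_cache_len, head_dim)

    # MQA layout: k and v stacked in a single tensor per layer,
    # shaped (batch_size, seq_len, n_heads * head_dim * 2)
    kv = torch.zeros(batch_size, current_len, n_heads * head_dim * 2)
    kv_padded = F.pad(kv, (0, 0, 0, pad_amount))  # same spec, same dim -2
    assert kv_padded.shape == (batch_size, kv_cache_len, n_heads * head_dim * 2)

In both layouts the sequence axis is dim -2, which is why a single pad specification covers both branches: F.pad reads the tuple from the last dimension backwards, two values per dimension.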