optimum/neuron/models/inference/backend/modules/flashdecode/utils.py
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/aws-neuron/neuronx-distributed-inference/blob/9993358ce052fd7a1bb4a7497a6318aac36ed95c/src/neuronx_distributed_inference/modules/flashdecode/utils.py
from typing import Tuple

import torch
from neuronx_distributed.parallel_layers.utils import divide

# Theoretically, one extra reserved position should be enough to avoid writing the true value
# to the garbage position (the last position of the cache); 128 is picked because it is more
# compatible with the compiler.
EXTRA_RESERVED_SPACE = 128
def get_cache_size(seq_len, num_cores_per_group):
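    # Example: with seq_len=1024 and num_cores_per_group=4 (as in the core layout sketched
    # in mask_util below), each core's cache holds 1024 // 4 + 128 = 384 positions.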
return divide(seq_len, num_cores_per_group) + EXTRA_RESERVED_SPACE
def turn_2d_mask_to_4d(attention_mask, n_positions, batch_size):
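    # Expands a [batch_size, n_positions] mask to a boolean [batch_size, 1, 1, n_positions]
    # tensor, so it can broadcast over attention heads and query positions.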
return attention_mask[:, None, None, :].expand(batch_size, 1, 1, n_positions).to(torch.bool)
def calculate_num_cores_per_group(num_attn_heads, num_kv_heads, tp_degree):
    assert num_attn_heads % tp_degree == 0, (
        f"expected the number of attention heads to be a multiple of the TP degree, but got {num_attn_heads} and {tp_degree}"
    )
num_cores_per_group = divide(min(tp_degree, num_attn_heads), num_kv_heads)
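    # Example: with 32 attention heads, 8 KV heads and tp_degree=32 (the layout in
    # mask_util's comments below), num_cores_per_group = min(32, 32) // 8 = 4.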
return num_cores_per_group
def mask_util(
    pos_ids: torch.Tensor, rank_id: torch.Tensor, num_cores_per_group: int, cache_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    :param pos_ids: 2D [batch_size x n_active_tokens] tensor with the position ids of all sequences in a batch
    :param rank_id: current rank of the device
    :param num_cores_per_group: number of cores per KV group
    :param cache_size: size of the cache per core
    :return: a tuple (active_masks, prior_masks): the mask selecting whether this core updates the active KV,
        and the mask over this core's prior KV cache positions
    """
    assert pos_ids.dim() == 2, f"position ids have to be 2D, but got shape {pos_ids.shape}"
batch_sz, n_active_tokens = pos_ids.shape
    # Core layout: 32 cores split across 8 KV groups (columns), with 4 cores in each group (rows)
# 0, 1, 2, 3, 4, 5, 6, 7
# -------------------------------
# 0 | 0, 4, 8, 12, 16, 20, 24, 28
# 1 | 1, 5, 9, 13, 17, 21, 25, 29
# 2 | 2, 6, 10, 14, 18, 22, 26, 30
# 3 | 3, 7, 11, 15, 19, 23, 27, 31
# -------------------------------
# for rank id == 19:
# the rank_id_in_kv_group (row index) is 3, derived by 19 % 4
rank_id = torch.remainder(rank_id, num_cores_per_group)
# active masks: select only one core to update active KV
selected_core_idx = torch.remainder(pos_ids, num_cores_per_group)
active_masks = torch.where(selected_core_idx == rank_id, 1, 0).to(dtype=pos_ids.dtype)
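    # e.g. with num_cores_per_group=4, the token at position 19 is written by the core whose
    # in-group rank is 19 % 4 == 3; the other cores in the group get a 0 here.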
if n_active_tokens > 1: # speculation
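        # Combine the per-core selection mask with a causal mask, so each of the n_active_tokens
        # speculated queries only attends to active tokens at earlier (or equal) positions
        # that this core is responsible for.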
active_masks_causal = torch.full(
(n_active_tokens, n_active_tokens),
1,
device=pos_ids.device,
).tril(diagonal=0)
active_masks_causal = active_masks_causal[None, :, :].expand(batch_sz, n_active_tokens, n_active_tokens)
active_masks = active_masks[:, None, :].expand(batch_sz, n_active_tokens, n_active_tokens)
active_masks = torch.logical_and(active_masks, active_masks_causal).to(dtype=pos_ids.dtype)
    # prior masks: mark which positions in this core's KV cache hold already-processed tokens
    # Cache layout within one KV group: 4 cores (rows), each holding 8 positions (columns), i.e. cache_size=8
    # Note: number of positions per core = bucket_size // num_cores_per_group
# 0, 1, 2, 3, 4, 5, 6, 7
# -------------------------------
# 0 | 0, 4, 8, 12, 16, 20, 24, 28
# 1 | 1, 5, 9, 13, 17, 21, 25, 29
# 2 | 2, 6, 10, 14, 18, 22, 26, 30
# 3 | 3, 7, 11, 15, 19, 23, 27, 31
# -------------------------------
# for pos_id = 19:
# the selected_pos for prior masks to be updated (col index) is 4, derived by 19 // 4
# selected_pos = torch.div(pos_ids, num_cores_per_group, rounding_mode="floor")
num_processed_tokens = pos_ids.min(dim=-1, keepdim=True).values if n_active_tokens > 1 else pos_ids
selected_pos = torch.div(
torch.subtract(torch.add(num_processed_tokens, num_cores_per_group - 1), rank_id),
num_cores_per_group,
rounding_mode="floor",
)
mask_shape = (batch_sz, n_active_tokens, cache_size) if n_active_tokens > 1 else (batch_sz, cache_size)
    # init prior mask: positions from the start up to (but excluding) selected_pos are True, the rest False
position_ids_to_compare = (
selected_pos.unsqueeze(-1).expand(mask_shape) if n_active_tokens > 1 else selected_pos.expand(mask_shape)
)
mask = torch.arange(cache_size, device=pos_ids.device).expand(mask_shape)
prior_masks = torch.where(position_ids_to_compare > mask, 1, 0).to(dtype=pos_ids.dtype)
return active_masks, prior_masks
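

# A minimal usage sketch of mask_util, assuming the 32-core / 8-KV-group layout described
# in the comments above (num_cores_per_group=4) and a hypothetical cache_size of 8.
if __name__ == "__main__":
    pos_ids = torch.tensor([[19]])  # one sequence, decoding the token at position 19
    rank_id = torch.tensor([19])  # global rank 19 -> in-group rank 19 % 4 == 3
    active, prior = mask_util(pos_ids, rank_id, num_cores_per_group=4, cache_size=8)
    # active == [[1]] because position 19 maps to in-group rank 19 % 4 == 3, i.e. this core.
    # prior == [[1, 1, 1, 1, 0, 0, 0, 0]] because this core caches positions 3, 7, 11, 15
    # (cache slots 0-3) out of the 19 tokens processed so far.
    print(active, prior)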