# backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
import math
import os
import time
import torch
import torch.distributed
import numpy as np
from loguru import logger
from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
PreTrainedTokenizerBase,
AutoConfig,
AutoTokenizer,
GenerationConfig,
)
from typing import (
Any,
Iterable,
Optional,
Tuple,
List,
Type,
Dict,
Union,
)
import torch.nn.functional as F
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.models import Model
from text_generation_server.utils.log import log_master
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
pad_next_token_chooser_parameters,
)
from text_generation_server.models.types import (
Batch,
Tokens,
Generation,
GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.models.globals import (
BLOCK_SIZE,
REQUEST_LOGPROBS,
TGI_WIGGLE_ROOM,
get_adapter_to_index,
)
from text_generation_server.layers.attention import (
KVCache,
KVCompressCache,
Seqlen,
HPUPagedAttentionMetadata,
trim_attn_metadata,
trim_seqlen_metadata,
_async_h2d_tensor_copy,
)
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
from text_generation_server.utils.import_utils import (
empty_cache,
synchronize,
get_free_memory,
)
from text_generation_server.utils.prefill_chunking import (
get_max_prefill_tokens,
)
import vllm_hpu_extension.environment as environment
import habana_frameworks.torch as htorch
import itertools
from vllm_hpu_extension.bucketing.common import get_bucketing_context
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
tracer = trace.get_tracer(__name__)
def prepare_for_decode(
dtype, use_contiguous_pa, device, slots, block_tables, batch_size, bucketing_ctx
):
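"""Build the HPU paged attention metadata for a decode step.
Flattens the per-request block tables into block_list / block_groups / block_usage,
pads them to the bucket size from `bucketing_ctx` (scattering by block id when
contiguous PA is enabled) and returns a trimmed HPUPagedAttentionMetadata whose
tensors are copied asynchronously from CPU to the device.
Example (contiguous PA, no bucketing context): block_tables [[4, 2], [7]] gives
block_list [4, 2, 7], which is scattered by block id into [0, 0, 2, 0, 4, 0, 0, 7]
with block_groups [-1, -1, 0, -1, 0, -1, -1, 1].
"""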
# Prepare values if we need to continue decoding
# needed for HPUPagedAttentionMetadata preparation
def flatten(in_list):
return list(itertools.chain(*in_list))
def gather_list(input, indices, v):
return [input[i] if i is not None else v for i in indices]
def pad_list(input, k, v):
input_len = len(input)
target_len = (input_len + k - 1) // k * k
padding = target_len - input_len
return input + [v] * padding
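# Worked examples for the helpers above:
#   flatten([[1, 2], [3]])            -> [1, 2, 3]
#   gather_list([9, 8], [1, None], 0) -> [8, 0]
#   pad_list([3, 5, 7], k=4, v=0)     -> [3, 5, 7, 0]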
last_block_usage = [slot % BLOCK_SIZE + 1 for slot in slots]
block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
block_usage = [
[BLOCK_SIZE] * (len(bt) - 1) + [lbu]
for bt, lbu in zip(block_tables, last_block_usage)
if bt
]
block_list = flatten(block_tables)
block_groups = flatten(block_groups)
block_usage = flatten(block_usage)
assert len(block_list) == len(block_groups)
assert len(block_list) == len(block_usage)
if use_contiguous_pa:
block_bucket_size = max(max(block_list) + 1, len(block_list))
if bucketing_ctx is not None:
block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks(
block_bucket_size
)
indices: List[Any]
indices = [None] * block_bucket_size
for i, bid in enumerate(block_list):
indices[bid] = i
block_list = gather_list(block_list, indices, 0)
block_groups = gather_list(block_groups, indices, -1)
block_usage = gather_list(block_usage, indices, 1)
else:
block_bucket_size = len(block_list)
if bucketing_ctx is not None:
block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks(
block_bucket_size
)
block_list = pad_list(block_list, block_bucket_size, 0)
block_groups = pad_list(block_groups, block_bucket_size, -1)
block_usage = pad_list(block_usage, block_bucket_size, 1)
block_list = torch.tensor(block_list, dtype=torch.int, device="cpu")
block_groups = torch.tensor(block_groups, dtype=torch.int, device="cpu")
block_usage = torch.tensor(block_usage, dtype=dtype, device="cpu")
block_list_device = _async_h2d_tensor_copy(block_list)
block_groups_device = _async_h2d_tensor_copy(block_groups)
block_usage_device = _async_h2d_tensor_copy(block_usage)
return trim_attn_metadata(
HPUPagedAttentionMetadata(
block_list=block_list_device,
block_groups=block_groups_device,
block_usage=block_usage_device,
block_mapping=None,
attn_bias=None,
)
)
@dataclass
class FlashCausalLMBatch(Batch):
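"""A batch of requests laid out for flash/paged attention on HPU.
Several tensor fields are only populated later, by `prepare_for_prefill`,
`prepare_for_decode` or `generate_token`, as noted in the comments below.
"""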
batch_id: int
requests: List[generate_pb2.Request]
# request id -> idx in list mapping
requests_idx_mapping: Dict[int, int]
# Decoder values
# Can be a list for easy filtering
# If `input_ids` is a list, it needs to be materialized to a tensor first
input_ids: Union[torch.Tensor, List[List[int]]]
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
position_ids: Optional[torch.Tensor]
speculative_ids: Optional[torch.Tensor]
# Set when creating the batch
# tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
slot_indices: Optional[torch.Tensor]
# list of length b of list of length s_i // block_size
block_tables: List[List[int]]
# tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
block_tables_tensor: torch.Tensor
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
slots: torch.Tensor
# list of length b + 1 containing the cumulative sequence slot lengths of the sequences in the batch
# used for filtering
cu_slots: torch.Tensor
max_input_length: int
max_current_length: int
# Whether this batch contains at least one request that is prefilling
prefilling: bool
# Whether each request is prefilling
prefilling_mask: List[bool]
# Prefill metadata tensors to efficiently compute logprobs
# tensor of length b + 1 containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
cu_seqlen_prefill: Optional[torch.Tensor]
# Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
# as we only keep SLIDING_WINDOW values instead of the whole tensor
prefill_cache_indices: Optional[torch.Tensor]
# Will be set by `generate_token` and reset after each prefill forward
prefill_head_indices: Optional[torch.Tensor]
# Will be set by `generate_token` and reset after each prefill forward
prefill_next_token_indices: Optional[torch.Tensor]
# Will be set by `generate_token` and reset after each prefill forward
prefill_cu_outlens: Optional[List[int]]
# Will be set by `generate_token` and reset after each prefill forward
prefill_logprob_tokens: List[Optional[Tokens]]
# All tokens
all_input_ids: List[List[int]]
all_input_ids_tensor: torch.Tensor
# Lengths of all generations present in the batch
input_lengths: List[int]
# size [b], containing the number of blocks that can be retrieved from the cache
cache_lengths: List[int]
prompt_lengths: List[int]
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
input_lengths_tensor: Optional[torch.Tensor]
cache_lengths_tensor: Optional[torch.Tensor]
prompt_lengths_tensor: torch.Tensor
prefix_offsets: List[Optional[int]]
read_offsets: List[Optional[int]]
# Generation helpers
next_token_chooser: HeterogeneousNextTokenChooser
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Adapter metadata for each request
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
adapter_meta: Optional[AdapterBatchMetadata]
# Number of blocks in this batch
num_blocks: int
# Maximum number of blocks
max_blocks: int
hpu_attn_meta: Optional[HPUPagedAttentionMetadata]
next_token_logits: Optional[torch.Tensor]
speculative_logits: Optional[torch.Tensor]
valid_indices: Optional[List[int]]
def to_pb(self) -> generate_pb2.CachedBatch:
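"""Summarize this batch as a protobuf `CachedBatch` (request ids, size and token counts)."""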
return generate_pb2.CachedBatch(
id=self.batch_id,
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.num_blocks * BLOCK_SIZE,
current_tokens=(
sum([len(i) for i in self.input_ids])
if isinstance(self.input_ids, list)
else len(self.input_ids)
),
)
@classmethod
def batch_tokenized_inputs(
cls, requests: Iterable[generate_pb2.Request], tokenizer
):
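"""Tokenize every request's concatenated text chunks.
Returns one list of token ids per request, truncated to `r.truncate` and
honoring `r.add_special_tokens`.
"""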
max_length = 0
all_input_ids = []
batch_size = 0
for r in requests:
batch_size += 1
inputs = concat_text_chunks(r.input_chunks.chunks)
input_ids = tokenizer(
inputs,
truncation=True,
max_length=r.truncate,
add_special_tokens=r.add_special_tokens,
)["input_ids"]
max_length = max(max_length, len(input_ids))
all_input_ids.append(input_ids)
return all_input_ids
@classmethod
def from_tokenized(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
batch_tokenized_inputs,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
cache_lengths = []
input_lengths = []
prompt_lengths = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
all_postfix_ids = []
requests_idx_mapping = {}
slots = []
cu_slots = [0]
next_token_chooser_parameters = []
stopping_criterias = []
top_n_tokens = []
num_blocks = 0
max_input_length = 0
max_current_length = 0
max_length = 0
max_blocks = 0
cu_blocks = [0]
block_tables = []
block_tables_ragged = []
# Parse batch
for i, (r, tokenized_input) in enumerate(
zip(pb.requests, batch_tokenized_inputs)
):
### XXX: This consumes so much memory on long requests
### Deactivating it by default seems like the best course.
if not REQUEST_LOGPROBS:
r.prefill_logprobs = False
else:
assert False, "prefill_logprobs not supported yet"
# request id -> idx in list mapping
requests_idx_mapping[r.id] = i
prompt_length = len(tokenized_input)
prompt_lengths.append(prompt_length)
cache_length = r.cache_len
assert (
cache_length <= prompt_length
), f"Prefix {cache_length} vs input {prompt_length}"
if cache_length == prompt_length:
assert False, "unreachable"
# `chunk_len` is an optional field in the protobuf
# It is only set if the model supports chunking
# Use all the remaining ids
postfix_ids = tokenized_input[cache_length:]
input_length = len(postfix_ids)
input_lengths.append(input_length)
prefix_offsets.append(prompt_length - 5)
read_offsets.append(prompt_length)
all_postfix_ids.append(postfix_ids)
all_input_ids.append(tokenized_input)
next_token_chooser_parameters.append(r.parameters)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
max_new_tokens = stopping_criteria.max_new_tokens
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
# Paged attention
# Remove one as the first token does not have a past
speculative_length = get_speculate()
speculative_length = 0 if speculative_length is None else speculative_length
# Tokens that need to be mapped to blocks.
block_tokens = prompt_length + max_new_tokens - 1 + speculative_length
# blocks and slots can be empty (for example in warmup)
if not r.blocks:
needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
request_blocks = [
b for b in range(num_blocks, num_blocks + needed_blocks)
]
request_slots = [
s
for b in request_blocks
for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
]
else:
request_blocks = r.blocks
request_slots = r.slots
block_tables.append(request_blocks)
block_tables_ragged.extend(request_blocks)
cu_blocks.append(len(block_tables_ragged))
slots.extend(request_slots)
cu_slots.append(len(slots))
cache_lengths.append(cache_length)
num_blocks += len(request_blocks)
# Update
max_blocks = max(max_blocks, len(request_blocks))
max_input_length = max(max_input_length, input_length)
max_current_length = max(max_current_length, cache_length + input_length)
max_length = max(
max_length,
prompt_length + max_new_tokens + speculative_length,
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device, tokenizer
)
# Padded all_input_ids_tensor
all_input_ids_tensor = np.zeros(
(len(all_input_ids), max_length), dtype=np.int64
)
for i, input_ids in enumerate(all_input_ids):
all_input_ids_tensor[i, : len(input_ids)] = input_ids
# put on cpu temporarily, move to hpu in prepare_for_prefill
all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64)
top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)
block_tables_ragged = torch.tensor(block_tables_ragged, dtype=torch.int32)
cu_blocks = torch.tensor(cu_blocks, dtype=torch.int64)
block_tables_tensor = torch.empty(
(len(block_tables), max_blocks),
dtype=torch.int32,
)
for i, request_blocks in enumerate(block_tables):
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
prompt_lengths_tensor = torch.tensor(prompt_lengths, dtype=torch.int32)
slots = torch.tensor(slots, dtype=torch.int64)
cu_slots = torch.tensor(cu_slots, dtype=torch.int64)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=all_postfix_ids,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
cache_lengths=cache_lengths,
max_input_length=max_input_length,
max_current_length=max_current_length,
prefilling=True,
prefilling_mask=[True] * len(pb.requests),
prefill_logprob_tokens=[None] * len(pb.requests),
input_lengths=input_lengths,
prompt_lengths=prompt_lengths,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
num_blocks=num_blocks,
max_blocks=max_blocks,
speculative_ids=None,
prompt_lengths_tensor=prompt_lengths_tensor,
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
position_ids=None,
cu_seqlen_prefill=None,
prefill_cache_indices=None,
slot_indices=None,
slots=slots,
cu_slots=cu_slots,
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
cache_lengths_tensor=None,
input_lengths_tensor=None,
adapter_meta=None,
hpu_attn_meta=None,
next_token_logits=None,
speculative_logits=None,
valid_indices=None,
)
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
assert len(pb.requests) > 0
batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
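"""Keep only the requests in `request_ids` and return the filtered batch.
Re-indexes the CPU-side metadata (block tables, slots, lengths, stopping
criteria) and, for decoding batches, slices the device tensors; for prefilling
batches the tensor fields stay unset until `prepare_for_prefill` runs.
Returns `self` unchanged when the requested ids cover the whole batch.
"""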
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
# We assume that if len(requests) == len(self) then the requests are the same
if len(request_ids) == len(self):
return self
device = self.block_tables_tensor.device
# New values after filtering
requests_idx_mapping = {}
# Used to index into tensors
indices = []
# slots to keep after filtering
slot_filtering_indices = torch.zeros(self.slots.shape[0], dtype=torch.bool)
# Create on CPU to only move to the device once instead of at every copy
slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
max_input_length = 0
max_current_length = 0
requests = []
block_tables = []
all_input_ids = []
input_ids = []
prompt_lengths = []
input_lengths = []
cache_lengths = []
prefix_offsets = []
read_offsets = []
cu_slots = [0]
prefilling_mask = []
prefill_logprob_tokens = []
stopping_criterias = []
adapter_set = set()
num_blocks = 0
max_blocks = 0
max_slots = 0
cumulative_slot_tokens = 0
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
indices.append(idx)
requests_idx_mapping[request_id] = i
requests.append(self.requests[idx])
# Prefilling
request_prefilling = self.prefilling_mask[idx]
prefilling_mask.append(request_prefilling)
# Get length
request_input_length = self.input_lengths[idx]
request_cache_length = self.cache_lengths[idx]
max_input_length = max(max_input_length, request_input_length)
max_current_length = max(
max_current_length, request_cache_length + request_input_length
)
all_input_ids.append(self.all_input_ids[idx])
prompt_lengths.append(self.prompt_lengths[idx])
input_lengths.append(request_input_length)
cache_lengths.append(request_cache_length)
prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx])
ADAPTER_TO_INDEX = get_adapter_to_index()
adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
adapter_set.add(adapter_index)
request_block_table = self.block_tables[idx]
num_blocks += len(request_block_table)
block_tables.append(request_block_table)
start_slot = self.cu_slots[idx]
end_slot = self.cu_slots[idx + 1]
slot_length = end_slot - start_slot
# Set slice
slot_filtering_indices[start_slot:end_slot] = True
cu_slots.append(cumulative_slot_tokens + slot_length)
# Input ids if the request was part of a prefilling batch
# If the batch was decoding we can index into the tensor directly later
if self.prefilling:
input_ids.append(self.input_ids[idx])
else:
# Copy to tensor (CPU)
slot_indices[i] = cumulative_slot_tokens + request_cache_length
cumulative_slot_tokens += slot_length
max_blocks = max(max_blocks, len(request_block_table))
max_slots = max(max_slots, slot_length)
block_tables_tensor = self.block_tables_tensor[indices]
prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
cu_slots = torch.tensor(cu_slots, dtype=torch.int64)
slots = self.slots[slot_filtering_indices]
if self.prefilling:
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
position_ids = None
slot_indices = None
cache_lengths_tensor = None
input_lengths_tensor = None
adapter_meta = None
else:
# Index into tensors
input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices]
input_lengths_tensor = self.input_lengths_tensor[indices]
cache_lengths_tensor = self.cache_lengths_tensor[indices]
# Move to the device now that we have the whole tensor
slot_indices = slot_indices.to(device)
if self.adapter_meta is not None:
adapter_indices = self.adapter_meta.adapter_indices[indices]
adapter_segments, adapter_segment_indices = find_segments(
adapter_indices
)
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
)
else:
adapter_meta = None
htorch.core.mark_step()
return type(self)(
batch_id=self.batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
prefill_cache_indices=None,
slot_indices=slot_indices,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
slots=slots,
cu_slots=cu_slots,
max_input_length=max_input_length,
max_current_length=max_current_length,
prefilling=self.prefilling,
prefilling_mask=prefilling_mask,
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
prefill_logprob_tokens=prefill_logprob_tokens,
prompt_lengths=prompt_lengths,
prompt_lengths_tensor=prompt_lengths_tensor,
input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
cache_lengths=cache_lengths,
cache_lengths_tensor=cache_lengths_tensor,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=self.all_input_ids_tensor,
next_token_chooser=self.next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=self.top_n_tokens,
top_n_tokens_tensor=self.top_n_tokens_tensor,
num_blocks=num_blocks,
max_blocks=max_blocks,
speculative_ids=self.speculative_ids,
adapter_meta=adapter_meta,
hpu_attn_meta=None,
valid_indices=indices,
next_token_logits=self.next_token_logits,
speculative_logits=self.speculative_logits,
)
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(
cls, batches: List["FlashCausalLMBatch"], padded_total_bs: int = 0
) -> "FlashCausalLMBatch":
# Batch attributes
requests = []
requests_idx_mapping = {}
prefilling = False
num_blocks = 0
total_batch_size = 0
total_slots = 0
max_blocks = 0
max_length = 0
max_input_length = 0
max_current_length = 0
ADAPTER_TO_INDEX = get_adapter_to_index()
for b in batches:
total_batch_size += len(b)
max_blocks = max(max_blocks, b.max_blocks)
total_slots += len(b.slots)
num_blocks += b.num_blocks
speculative_length = (
b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
)
max_input_length = max(max_input_length, b.max_input_length)
max_current_length = max(max_current_length, b.max_current_length)
max_length = max(
max_length,
max(
prompt_length
+ stopping_criteria.max_new_tokens
+ speculative_length
for prompt_length, stopping_criteria in zip(
b.prompt_lengths, b.stopping_criterias
)
),
)
prefilling = prefilling or b.prefilling
slots = batches[0].slots.new_empty(total_slots)
cu_slots = torch.zeros(total_batch_size + 1, dtype=torch.int64)
if prefilling:
input_ids = []
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
position_ids = None
slot_indices = None
cache_lengths_tensor = None
input_lengths_tensor = None
adapter_meta = None
adapter_segment_builder = None
else:
if padded_total_bs == batches[0].input_ids.shape[0]:
input_ids = batches[0].input_ids
else:
input_ids = batches[0].input_ids.new_empty(total_batch_size)
if (
batches[0].position_ids is not None
and batches[0].position_ids.dim() == 2
):
# Qwen2_vl case:
position_ids = batches[0].position_ids.new_empty(
(total_batch_size, batches[0].position_ids.shape[-1])
)
else:
position_ids = batches[0].position_ids.new_empty(total_batch_size)
slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
total_batch_size
)
cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
total_batch_size
)
if ADAPTER_TO_INDEX:
total_indices_size = sum(
b.adapter_meta.adapter_indices.shape[0] for b in batches
)
adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
total_indices_size
)
adapter_segment_builder = SegmentConcatBuilder()
adapter_set = set()
prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
total_batch_size
)
block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
(total_batch_size, max_blocks)
)
all_input_ids_tensor = batches[0].all_input_ids_tensor
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
block_tables = []
cache_lengths = []
all_input_ids = []
prompt_lengths = []
input_lengths = []
prefix_offsets = []
read_offsets = []
prefill_logprob_tokens = []
next_token_chooser_parameters = []
fsm_grammar_states = []
stopping_criterias = []
top_n_tokens = []
prefilling_mask = []
# Cumulative length
cumulative_batch_size = 0
cumulative_slots = 0
cumulative_adapter_indices_size = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
valid_bsize = len(batch)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
# We need to offset the mapping for each batch by the cumulative batch size
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + cumulative_batch_size
start_index = cumulative_batch_size
end_index = cumulative_batch_size + valid_bsize
index = torch.tensor(list(range(start_index, end_index)), device="cpu")
top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
if i > 0:
all_input_ids_tensor.index_copy_(
0,
index.to(batch.all_input_ids_tensor.device),
batch.all_input_ids_tensor[:valid_bsize, :],
)
block_tables_tensor[
start_index:end_index, : batch.block_tables_tensor.shape[1]
] = batch.block_tables_tensor[:, :max_blocks]
prompt_lengths_tensor.index_copy_(0, index, batch.prompt_lengths_tensor)
slots_start_index = cumulative_slots
slots_end_index = cumulative_slots + len(batch.slots)
slot_index = torch.tensor(
list(range(slots_start_index, slots_end_index)),
device=batch.slots.device,
)
slots.index_copy_(0, slot_index, batch.slots)
cu_slots[start_index + 1 : end_index + 1] = (
batch.cu_slots[1:] + cumulative_slots
)
if not prefilling:
if padded_total_bs != batches[0].input_ids.shape[0] or i > 0:
input_ids.index_copy_(
0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
)
position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
slot_indices.index_copy_(
0, index, batch.slot_indices + cumulative_slots
)
input_lengths_tensor.index_copy_(
0, index, batch.input_lengths_tensor[:valid_bsize]
)
cache_lengths_tensor.index_copy_(
0, index, batch.cache_lengths_tensor[:valid_bsize]
)
if ADAPTER_TO_INDEX:
adapter_start_index = cumulative_adapter_indices_size
adapter_end_index = (
cumulative_adapter_indices_size
+ batch.adapter_meta.adapter_indices.shape[0]
)
adapter_indices[adapter_start_index:adapter_end_index] = (
batch.adapter_meta.adapter_indices
)
cumulative_adapter_indices_size = adapter_end_index
adapter_set.update(batch.adapter_meta.adapter_set)
adapter_segment_builder.concat(
batch.adapter_meta.adapter_segments,
batch.adapter_meta.segment_indices,
)
else:
if isinstance(batch.input_ids, torch.Tensor):
batch.input_ids = batch.input_ids.view(-1, 1).tolist()
input_ids.extend(batch.input_ids)
prefilling_mask.extend(batch.prefilling_mask)
block_tables.extend(batch.block_tables)
cache_lengths.extend(batch.cache_lengths)
all_input_ids.extend(batch.all_input_ids)
prompt_lengths.extend(batch.prompt_lengths)
input_lengths.extend(batch.input_lengths)
prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets)
prefill_logprob_tokens.extend(batch.prefill_logprob_tokens)
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
# Update
cumulative_slots += len(batch.slots)
cumulative_batch_size += len(batch)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters,
dtype=batches[0].next_token_chooser.dtype,
device=batches[0].next_token_chooser.device,
tokenizer=batches[0].next_token_chooser.tokenizer,
fsm_grammar_states=fsm_grammar_states,
)
# We skip computing the speculative_ids when the batch size is too large, so
# we must check that all batches have them, otherwise they must be discarded
if get_speculate() > 0 and all(b.speculative_ids is not None for b in batches):
speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
else:
speculative_ids = None
if ADAPTER_TO_INDEX and adapter_segment_builder is not None:
adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
)
return cls(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
prefill_cache_indices=None,
slot_indices=slot_indices,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
cache_lengths=cache_lengths,
cache_lengths_tensor=cache_lengths_tensor,
slots=slots,
cu_slots=cu_slots,
max_input_length=max_input_length,
max_current_length=max_current_length,
prefilling=prefilling,
prefilling_mask=prefilling_mask,
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
prefill_logprob_tokens=prefill_logprob_tokens,
prompt_lengths=prompt_lengths,
prompt_lengths_tensor=prompt_lengths_tensor,
input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
num_blocks=num_blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids,
adapter_meta=adapter_meta if ADAPTER_TO_INDEX else None,
hpu_attn_meta=None,
next_token_logits=None,
speculative_logits=None,
valid_indices=None,
)
def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx, pad_token_id):
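"""Pad this batch to the bucketed decode batch size and build the HPU metadata.
Trims block tables to the blocks actually in use, pads input ids, position ids
and length tensors up to the padded batch size, rebuilds the next-token chooser
with padded parameters and carried-over grammar states, and stores the paged
attention metadata in `hpu_attn_meta`.
"""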
block_num = [length // BLOCK_SIZE + 1 for length in self.cache_lengths]
block_tables = []
for i, bt in enumerate(self.block_tables):
block_tables.append(bt[0 : block_num[i]])
if bucketing_ctx is not None:
padded_bs = bucketing_ctx.get_padded_decode_batch_size(
self.input_ids.shape[0]
)
else:
padded_bs = self.input_ids.shape[0]
slots = self.slots[self.slot_indices]
self.hpu_attn_meta = prepare_for_decode(
dtype,
use_contiguous_pa,
"hpu",
slots,
block_tables,
padded_bs,
bucketing_ctx,
)
self.input_ids = F.pad(
self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=pad_token_id
)
if self.position_ids.dim() == 2:
# Qwen VL case
self.position_ids = F.pad(
self.position_ids,
(0, 0, 0, padded_bs - self.position_ids.shape[0]),
value=1,
)
else:
self.position_ids = F.pad(
self.position_ids, (0, padded_bs - self.position_ids.shape[0]), value=1
)
self.input_lengths_tensor = F.pad(
self.input_lengths_tensor,
(0, padded_bs - self.input_lengths_tensor.shape[0]),
value=0,
)
self.cache_lengths_tensor = F.pad(
self.cache_lengths_tensor,
(0, padded_bs - self.cache_lengths_tensor.shape[0]),
value=0,
)
next_token_chooser_parameters = []
next_token_chooser_parameters.extend([r.parameters for r in self.requests])
pad_next_token_chooser_parameters(next_token_chooser_parameters, padded_bs)
# update past grammar states
fsm_grammar_states = [0] * padded_bs
for i, req in enumerate(self.requests):
fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters,
self.next_token_chooser.dtype,
self.next_token_chooser.device,
self.next_token_chooser.tokenizer,
fsm_grammar_states,
)
def prepare_for_prefill(
self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
):
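"""Left-pad and lay out the batch for an HPU prefill forward.
Pads every request to `max_padded_input_len` (and the batch to `max_padded_bs`),
builds position ids, slot indices, the boolean `prefill_cache_indices` mask used
to skip padded kv-cache slots, the logprob head/next-token indices, the padded
`all_input_ids_tensor` on HPU, the padded next-token chooser and the adapter
metadata.
"""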
# Prepare values if we need to continue prefilling
# Speculation must be ignored while we prefill, even with chunking:
# it simplifies everything
assert self.speculative_ids is None
# device = self.block_tables_tensor.device
# HPU does not support varlen for prefill, so SDPA is used instead and the input
# ids / position ids need to be padded. Padding is applied on the left to work
# with sliding windows. `prefill_cache_indices` marks the valid kv slots and
# `prefill_next_token_indices` is updated to point at the right logit positions.
input_ids_padded_length = []
# need extra pad to match warmup seq
extra_pad = max_padded_input_len - self.max_input_length
extra_pad_bs = max_padded_bs - len(self)
device = "hpu"
if isinstance(self.input_ids, list) and len(self) > 1:
input_ids_padded_length = []
input_ids = []
for input_id in self.input_ids:
padded = self.max_input_length - len(input_id) + extra_pad
if padded > 0:
input_id = [pad_token_id] * padded + input_id
input_ids.append(input_id)
input_ids_padded_length.append(padded)
input_ids = np.concatenate(input_ids, dtype=np.int64)
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
elif isinstance(self.input_ids, list):
input_ids = self.input_ids[0]
input_ids_padded_length.append(extra_pad)
input_ids = [pad_token_id] * extra_pad + input_ids
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
else:
input_ids = torch.full(
(max_padded_input_len * len(self),),
pad_token_id,
dtype=torch.int64,
device=self.input_ids.device,
)
src_pos = 0
for i in range(len(self)):
end_pos = (i + 1) * max_padded_input_len
start_pos = end_pos - self.input_lengths[i]
input_ids[start_pos:end_pos] = self.input_ids[
src_pos : src_pos + self.input_lengths[i]
]
input_ids_padded_length.append(
max_padded_input_len - self.input_lengths[i]
)
src_pos += self.input_lengths[i]
self.input_ids = input_ids
self.input_ids = F.pad(
self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=pad_token_id
)
self.input_lengths_tensor = torch.tensor(self.input_lengths, dtype=torch.int32)
self.input_lengths_tensor = F.pad(
self.input_lengths_tensor, (0, extra_pad_bs), value=0
)
cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(max_padded_bs + 1)
torch.cumsum(self.input_lengths_tensor, out=cu_seqlen_prefill[1:], dim=0)
self.cu_seqlen_prefill = cu_seqlen_prefill.to(torch.int32)
self.cache_lengths_tensor = torch.tensor(self.cache_lengths, dtype=torch.int32)
self.cache_lengths_tensor = F.pad(
self.cache_lengths_tensor, (0, extra_pad_bs), value=0
)
position_ids = []
slot_indices = []
prefill_cache_indices = []
all_prefill_logprobs = True
no_prefill_logprobs = True
prefill_cu_outlens = [0]
# Cumulative length
cumulative_length = 0
cumulative_slot_tokens = 0
prefill_out_cumulative_length = 0
adapter_indices_list = []
adapter_set = set()
for i, (
r,
cache_length,
input_length,
prompt_length,
request_prefilling,
blocks,
) in enumerate(
zip(
self.requests,
self.cache_lengths,
self.input_lengths,
self.prompt_lengths,
self.prefilling_mask,
self.block_tables,
)
):
next_chunk_length = input_length
# Position ids
request_position_ids = torch.arange(
cache_length, cache_length + input_length, dtype=torch.int32
)
request_position_ids = F.pad(
request_position_ids, (input_ids_padded_length[i], 0), value=1
)
position_ids.append(request_position_ids)
if not r.slots:
request_slots = [
s
for b in blocks
for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
]
else:
request_slots = r.slots
request_slot_indices = torch.arange(
cache_length + cumulative_slot_tokens,
cache_length + cumulative_slot_tokens + input_length,
dtype=torch.int64,
)
slot_indices.append(request_slot_indices)
# Update
cumulative_slot_tokens += len(request_slots)
# Create tensor to slice into the kv tensor in prefill
# HPU needs request_prefill_cache_indices to skip the padding in the kv cache
sliding_window = input_length
cumulative_length += input_ids_padded_length[i]
if sliding_window is not None:
request_prefill_cache_indices = torch.arange(
cumulative_length + max(0, input_length - sliding_window),
cumulative_length + input_length,
dtype=torch.int64,
)
# Prefill logprobs are ignored if the request is done prefilling
prefill_logprobs = r.prefill_logprobs and request_prefilling
all_prefill_logprobs = all_prefill_logprobs and prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs
if prefill_logprobs:
prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
prefill_out_cumulative_length += input_length
else:
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
prefill_cache_indices.append(request_prefill_cache_indices)
ADAPTER_TO_INDEX = get_adapter_to_index()
if ADAPTER_TO_INDEX:
adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
adapter_indices_list.append(
torch.full((next_chunk_length,), adapter_index)
)
adapter_set.add(adapter_index)
# Update
cumulative_length += next_chunk_length
if not all_prefill_logprobs and not no_prefill_logprobs:
prefill_head_indices = []
prefill_next_token_indices = []
# Cumulative length
cumulative_length = 0
prefill_out_cumulative_length = 0
for i, (
r,
input_length,
request_prefilling,
) in enumerate(
zip(
self.requests,
self.input_lengths,
self.prefilling_mask,
)
):
# Prefill logprobs are ignored if the request is done prefilling
prefill_logprobs = r.prefill_logprobs and request_prefilling
if prefill_logprobs:
prefill_head_indices.append(
torch.arange(
cumulative_length,
cumulative_length + input_length,
dtype=torch.int32,
)
)
prefill_next_token_indices.append(
prefill_out_cumulative_length + input_length - 1
)
prefill_out_cumulative_length += input_length
else:
prefill_head_indices.append(
torch.tensor(
[cumulative_length + input_length - 1],
dtype=torch.int32,
)
)
prefill_next_token_indices.append(prefill_out_cumulative_length)
prefill_out_cumulative_length += 1
# Update
cumulative_length += input_length
if len(self) > 1:
if position_ids:
position_ids = torch.cat(position_ids)
if slot_indices:
slot_indices = torch.cat(slot_indices)
prefill_cache_indices = torch.cat(prefill_cache_indices)
else:
if position_ids:
position_ids = position_ids[0]
if slot_indices:
slot_indices = slot_indices[0]
prefill_cache_indices = prefill_cache_indices[0]
self.position_ids = position_ids
self.position_ids = F.pad(
self.position_ids, (0, extra_pad_bs * max_padded_input_len), value=1
)
self.slot_indices = slot_indices
self.prefill_cu_outlens = prefill_cu_outlens
self.prefill_cache_indices = torch.zeros_like(
self.input_ids, dtype=torch.bool, device="cpu"
)
self.prefill_cache_indices[prefill_cache_indices] = True
if all_prefill_logprobs:
prefill_head_indices = None
prefill_next_token_indices = self.cu_seqlen_prefill[1:] - 1
elif no_prefill_logprobs:
prefill_head_indices = self.cu_seqlen_prefill[1:] - 1
prefill_next_token_indices = None
else:
prefill_head_indices = torch.cat(prefill_head_indices)
prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64
)
self.prefill_head_indices = prefill_head_indices
self.prefill_next_token_indices = prefill_next_token_indices
input_ids_padded_length_tensor = torch.cumsum(
torch.tensor(input_ids_padded_length, dtype=torch.int32),
dim=-1,
).to(torch.int32)
input_ids_padded_length_tensor = F.pad(
input_ids_padded_length_tensor, (0, extra_pad_bs), value=0
)
if self.prefill_head_indices is not None:
self.prefill_head_indices = (
self.prefill_head_indices + input_ids_padded_length_tensor
)
if self.prefill_next_token_indices is not None:
self.prefill_next_token_indices = (
self.prefill_next_token_indices + input_ids_padded_length_tensor
)
all_input_ids_tensor = torch.full(
(max_padded_bs, max(max_total_tokens, self.all_input_ids_tensor.shape[-1])),
pad_token_id,
dtype=torch.int64,
device="hpu",
)
for i in range(len(self)):
all_input_ids_tensor[i, : self.all_input_ids_tensor.shape[-1]] = (
self.all_input_ids_tensor[i]
)
self.all_input_ids_tensor = all_input_ids_tensor
next_token_chooser_parameters = []
next_token_chooser_parameters.extend([r.parameters for r in self.requests])
pad_next_token_chooser_parameters(next_token_chooser_parameters, max_padded_bs)
# update past grammar states
fsm_grammar_states = [0] * max_padded_bs
for i, req in enumerate(self.requests):
fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters,
self.next_token_chooser.dtype,
self.next_token_chooser.device,
self.next_token_chooser.tokenizer,
fsm_grammar_states,
)
if ADAPTER_TO_INDEX:
if adapter_set:
adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)
adapter_segments, adapter_segment_indices = find_segments(
adapter_indices
)
else:
adapter_indices = torch.zeros_like(self.input_ids)
adapter_segments = [0, len(adapter_indices)]
adapter_segment_indices = [len(adapter_indices) - 1]
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
self.adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
)
def __len__(self):
return len(self.requests)
ADAPTER_LAYERS = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class FlashCausalLM(Model):
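"""Causal LM served with flash/paged attention on Gaudi (HPU) devices."""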
def __init__(
self,
model_id: str,
model_class,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
lora_adapter_ids: Optional[list] = [],
tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
config_class=AutoConfig,
default_dtype=torch.float16,
aliases=None,
# Used for Santacoder override of config
num_kv_heads: Optional[int] = None,
# Deepseek V2 uses different QK and V dims.
head_size: Optional[int] = None,
skip_special_tokens: bool = True,
kv_cache_dtype: Optional[torch.dtype] = None,
support_chunking: bool = True,
):
self.quantize = quantize
self.process_group, rank, world_size = initialize_torch_distributed()
if world_size > 1:
self.process_group_cpu = torch.distributed.new_group(backend="gloo")
device = torch.device("hpu")
dtype = torch.bfloat16 if dtype is None else dtype
tokenizer = tokenizer_class.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
try:
generation_config = GenerationConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
if isinstance(generation_config.eos_token_id, (list, set)):
# TODO Huge hack
tokenizer._eos_token_ids = set(generation_config.eos_token_id)
except Exception:
pass
config = config_class.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
weights_loader = get_loader(quantize, model_id, revision)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device,
dtype,
process_group=self.process_group,
aliases=aliases,
weights_loader=weights_loader,
)
prefix = None
model = model_class(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
# VLM models define the config we care about in their text_config
text_config = getattr(config, "text_config", None)
if text_config is not None:
config = text_config
if getattr(config, "sliding_window", None) is None:
config.sliding_window = None
self.num_layers = config.num_hidden_layers
self.num_heads = config.num_attention_heads // self.process_group.size()
self.config = config
# Validation is done in the model itself
if num_kv_heads is None:
num_kv_heads = getattr(config, "num_key_value_heads", None)
# GPT-2 workaround
if num_kv_heads is None:
num_kv_heads = getattr(config, "n_head", None)
if num_kv_heads is None:
raise ValueError("Cannot get the number of key/value heads")
self.num_kv_heads = (
num_kv_heads // self.process_group.size()
if num_kv_heads // self.process_group.size() > 0
else num_kv_heads
)
assert self.num_kv_heads > 0
if head_size is None:
# Some models use GQA and different sizes for o_proj
# and q_proj; this accounts for that.
if getattr(config, "head_dim", None) is not None:
self.head_size = config.head_dim
else:
self.head_size = config.hidden_size // config.num_attention_heads
else:
self.head_size = head_size
self.cuda_graphs = {}
self.kv_cache = []
self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
self.bucketing_ctx = None
self.max_total_tokens = None
self.max_input_tokens = None
htorch.core.hpu_set_env()
if htorch.utils.internal.is_lazy():
htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True)
environment.set_model_config(self.config)
self.use_contiguous_pa = (
os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true"
)
self.limit_hpu_graph = (
os.environ.get("LIMIT_HPU_GRAPH", "false").lower() == "true"
)
self.skip_warmup = os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true"
self.max_seq_len_to_capture = 8192
if tokenizer.pad_token_id is None:
if config.pad_token_id is not None:
tokenizer.pad_token_id = config.pad_token_id
elif config.eos_token_id is not None:
tokenizer.pad_token_id = (
config.eos_token_id[0]
if isinstance(config.eos_token_id, list)
else config.eos_token_id
)
elif tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
else:
tokenizer.pad_token_id = 0
super().__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
sliding_window=config.sliding_window,
support_chunking=support_chunking,
)
@property
def batch_type(self) -> Type[FlashCausalLMBatch]:
return FlashCausalLMBatch
def max_past(self) -> Optional[int]:
return getattr(self.model, "max_past", None)
def init_kv_cache(
self,
num_blocks: int,
num_layers: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
):
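"""Allocate one KV cache per layer: `KVCompressCache` (sized kv_lora_rank +
qk_rope_head_dim) for deepseek_v3, a regular paged `KVCache` otherwise."""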
self.kv_cache = []
empty_cache()
if self.config.model_type == "deepseek_v3":
self.kv_cache = [
KVCompressCache(
num_blocks=num_blocks,
head_size=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
dtype=dtype,
device=device,
)
for _ in range(num_layers)
]
else:
self.kv_cache = [
KVCache(
num_blocks=num_blocks,
num_heads=num_heads,
head_size=head_size,
dtype=dtype,
device=device,
)
for _ in range(num_layers)
]
def warmup(
self,
batch: FlashCausalLMBatch,
max_input_tokens: Optional[int],
max_total_tokens: Optional[int],
):
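"""Size the KV cache and warm up the HPU graphs.
Runs one prefill on the warmup batch to measure peak memory, derives the number
of KV cache blocks from the remaining free memory, builds the HPU bucketing
context and, unless VLLM_SKIP_WARMUP is set, warms up the prompt and decode
buckets. Returns the KV cache token budget together with the resolved
max_input_tokens and max_total_tokens.
"""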
if os.environ.get("MAX_BATCH_SIZE") is None:
raise RuntimeError(
"MAX_BATCH_SIZE is not set, it should be set in the launcher "
"using `--max-batch-size xxx`"
)
# The warmup batch is the biggest batch we could ever receive
self.kv_cache = []
empty_cache()
self.graphed_buckets = set()
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
if self.config.model_type == "deepseek_v3":
cache_block_size = BLOCK_SIZE * (
self.config.kv_lora_rank + self.config.qk_rope_head_dim
)
else:
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
cache_block_size = cache_block_size * 2
total_cache_size = self.num_layers * cache_block_size * dtype_size
free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
self.mem_reserved = int(free_memory * (1 - MEMORY_FRACTION))
graph_reserved_mem = (
float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
if htorch.utils.internal.is_lazy()
else 0
)
mem_used_from_graph = int(
(free_memory - self.mem_reserved) * graph_reserved_mem
)
log_master(
logger.info,
f"Free memory on device {self.device}: {format_bytes(free_memory)} used_for_graph: {format_bytes(mem_used_from_graph)} ratio {graph_reserved_mem} reserved_for_runtime: {format_bytes(self.mem_reserved)}",
)
if max_total_tokens is None:
max_total_tokens = sum(batch.input_lengths)
if max_input_tokens is None:
max_input_tokens = max_total_tokens - 1
self.max_total_tokens = max_total_tokens
self.max_input_tokens = max_input_tokens
# Computed before the `try` so the error message below can reference them
# even if the allocation itself fails.
batch_num_blocks = batch.num_blocks
num_tokens = batch.to_pb().current_tokens
try:
self.init_kv_cache(
batch.num_blocks,
self.num_layers,
self.num_kv_heads,
self.head_size,
self.kv_cache_dtype,
self.device,
)
synchronize(self.device)
_, _batch, _ = self.generate_token([batch])
except Exception:
raise RuntimeError(
f"Not enough memory to handle {num_tokens} prefill tokens. "
f"You need to decrease `--max-batch-prefill-tokens`"
)
synchronize(self.device)
free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
kv_memory = free_memory - self.mem_reserved - mem_used_from_graph
num_blocks = (
# Leave 5% for some wiggle room
int(kv_memory // total_cache_size)
# Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
+ batch_num_blocks
)
log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
self.kv_cache = []
empty_cache()
self.init_kv_cache(
num_blocks,
self.num_layers,
self.num_kv_heads,
self.head_size,
self.kv_cache_dtype,
self.device,
)
self.max_batch_prefill_tokens = get_max_prefill_tokens()
max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
HPUBucketingContext = get_bucketing_context()
# need to warm up one more step since blocks are allocated starting from 1
block_step = int(os.getenv("VLLM_DECODE_BLOCK_BUCKET_STEP", BLOCK_SIZE))
max_total_tokens_aligned = math.ceil(
max_total_tokens / BLOCK_SIZE
) * BLOCK_SIZE + math.ceil(block_step * BLOCK_SIZE / max_num_seqs)
model_max_length = self.tokenizer.model_max_length
max_position_embeddings = getattr(
self.config, "max_position_embeddings", model_max_length
)
self.bucketing_ctx = HPUBucketingContext(
max_num_seqs,
max_num_seqs, # self.max_num_prefill_seqs, #TODO
BLOCK_SIZE,
max_num_seqs * max_total_tokens_aligned,
False,
min(model_max_length, max_position_embeddings),
max_input_tokens,
max_total_tokens_aligned,
)
max_blocks = max(
BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE
)
self.bucketing_ctx.num_hpu_blocks = min(max_blocks, num_blocks)
synchronize(self.device)
if self.skip_warmup:
self.bucketing_ctx.generate_prompt_buckets()
self.bucketing_ctx.generate_decode_buckets(
self.bucketing_ctx.num_hpu_blocks
)
log_master(
logger.info, "skip warmup hpu graph, not recommmended, may cause OOM"
)
del _batch, batch
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
self.warmup_hpu_graph(batch)
del _batch, batch
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
def log_warmup(self, prefilling, i, max_i, batch_size, seq_len):
free_mem = format_bytes(HabanaMemoryProfiler.current_free_device_memory())
phase = "Prompt" if prefilling else "Decode"
dim = "seq_len" if prefilling else "num_blocks"
graphed_bucket = (batch_size, seq_len, prefilling)
bypass = graphed_bucket not in self.graphed_buckets
msg = (
f"[Warmup][{phase}][{i+1}/{max_i}] "
f"batch_size:{batch_size} "
f"{dim}:{seq_len} "
f"bypass:{bypass} "
f"free_mem:{free_mem}"
", this may take a while..."
)
log_master(logger.info, msg)
def use_graphs(self, prefill, seq_len, batch_size):
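"""Return True when this (batch_size, seq_len or padded block count, prefill) bucket should run under an HPU graph."""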
if self.limit_hpu_graph and prefill:
return False
if self.skip_warmup:
return True
return (batch_size, seq_len, prefill) in self.graphed_buckets
def align_workers(self, value, op):
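"""All-reduce `value` with `op` over the gloo CPU process group; returns `value` unchanged for a single worker."""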
if self.world_size <= 1:
return value
value_t = torch.tensor(value, device="cpu")
torch.distributed.all_reduce(value_t, op=op, group=self.process_group_cpu)
return value_t.item()
def warmup_hpu_graph(self, batch):
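"""Capture HPU graphs for the prompt and decode buckets.
Budgets graph memory between prompt and decode using VLLM_GRAPH_PROMPT_RATIO,
then warms the buckets `warmup_times` times each under a memory profiler and
only captures graphs for those whose estimated footprint fits the budget; the
others bypass HPU graphs at runtime.
"""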
prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
free_mem = HabanaMemoryProfiler.current_free_device_memory()
graph_free_mem = free_mem - self.mem_reserved
graph_free_mem = self.align_workers(
graph_free_mem, torch.distributed.ReduceOp.MIN
)
prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
decode_available_memory = graph_free_mem - prompt_available_memory
msg = (
f"Using {format_bytes(graph_free_mem)}"
f"/{format_bytes(free_mem)} "
"of free device memory for HPUGraphs, "
f"{format_bytes(prompt_available_memory)} for prompt and "
f"{format_bytes(decode_available_memory)} for decode "
f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
)
log_master(logger.info, msg)
start_time = time.time()
warmup_shape_count = 0
warmup_times = 3
self.bucketing_ctx.generate_prompt_buckets()
def ordering_function_min_tokens(b):
return (b[0] * b[1], b[1], b[0])
buckets = list(
sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
)
total_batch_seq = 0.001
total_mem = 0
available_mem = prompt_available_memory
msg = (
f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n"
f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n"
)
log_master(logger.info, msg)
for i, (batch_size, seq_len) in enumerate(buckets):
if batch_size * seq_len > self.max_batch_prefill_tokens:
continue
# Graph memory usage is proportional to seq dimension in a batch
batch_seq = batch_size * seq_len
mem_estimate = batch_seq / total_batch_seq * total_mem
graphed_bucket = (batch_size, seq_len, True)
if not (
mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
):
if graphed_bucket not in self.graphed_buckets:
self.graphed_buckets.add(graphed_bucket)
warmup_shape_count += 1
self.log_warmup(True, i, len(buckets), batch_size, seq_len)
with HabanaMemoryProfiler() as mem_prof:
for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size, batch)
synchronize(self.device)
used_mem = self.align_workers(
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
)
if graphed_bucket in self.graphed_buckets:
available_mem -= used_mem
total_mem += used_mem
total_batch_seq += batch_seq
log_master(logger.info, "Prefill warmup successful.\n")
def ordering_function_max_bs(b):
return (-b[0], b[1])
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
buckets = list(
sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
)
free_mem = HabanaMemoryProfiler.current_free_device_memory()
total_batch_seq = 0.001
total_mem = 0
available_mem = free_mem - self.mem_reserved
log_master(
logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n"
)
for i, (batch_size, block_num) in enumerate(buckets):
if batch_size > block_num:
continue
# For decode, graph memory usage is estimated to be proportional to the batch size
batch_seq = batch_size
mem_estimate = batch_seq / total_batch_seq * total_mem
graphed_bucket = (batch_size, block_num, False)
if not mem_estimate >= available_mem:
if graphed_bucket not in self.graphed_buckets:
self.graphed_buckets.add(graphed_bucket)
warmup_shape_count += 1
self.log_warmup(False, i, len(buckets), batch_size, block_num)
with HabanaMemoryProfiler() as mem_prof:
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device)
used_mem = self.align_workers(
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
)
if graphed_bucket in self.graphed_buckets:
available_mem -= used_mem
total_mem += used_mem
total_batch_seq += batch_seq
log_master(logger.info, "Decode warmup successful.\n")
log_master(
logger.info,
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
)
def warmup_prefill(
self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch
):
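"""Run a synthetic prefill with `batch_size` sequences of length `prompt_len` to warm (and possibly capture) the matching HPU graph."""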
input_ids = torch.zeros(prompt_len, dtype=batch.input_ids.dtype).repeat(
batch_size
)
position_ids = torch.arange(prompt_len, dtype=batch.position_ids.dtype).repeat(
batch_size
)
max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size
block_tables = torch.arange(max_bt, dtype=torch.int32).reshape(batch_size, -1)
slot_acc = []
for i in range(batch_size):
slots = []
for b in block_tables[i]:
slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
slot_acc.extend(slots[:prompt_len])
slots = torch.tensor(slot_acc, dtype=batch.slots.dtype)
input_lengths = torch.ones(batch_size, dtype=torch.int32) * prompt_len
cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32)
torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
seqlen = Seqlen(
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
lm_head_indices = input_lengths - 1
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
True, prompt_len, batch_size
)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids),
position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=self.kv_cache,
slots=_async_h2d_tensor_copy(slots),
seqlen=trim_seqlen_metadata(seqlen),
lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
adapter_data=None,
hpu_attention_meta=None,
**kwargs,
)
def warmup_decode(self, batch_size: int, block_num: int, batch: FlashCausalLMBatch):
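"""Run a synthetic decode step with `batch_size` sequences spread over `block_num` KV cache blocks to warm (and possibly capture) the matching HPU graph."""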
input_ids = torch.zeros(batch_size, dtype=batch.input_ids.dtype)
position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype)
blocks = [block_num // batch_size for _ in range(batch_size)]
blocks[0] += block_num % batch_size
past_len = []
block_tables = []
slots = []
start_idx = 0
# use the last slot of each sequence's last block so the warmup touches `block_num` blocks
for i in range(batch_size):
block_array = list(range(start_idx, start_idx + blocks[i]))
slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
block_tables.append(block_array)
past_len.append(blocks[i] * BLOCK_SIZE - 1)
start_idx += blocks[i]
input_lengths = torch.ones(batch_size, dtype=torch.int32)
cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32)
torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
seqlen = Seqlen(
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
hpu_attention_meta = prepare_for_decode(
self.dtype,
self.use_contiguous_pa,
self.device,
slots,
block_tables,
batch_size,
bucketing_ctx=None,
)
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
False, hpu_attention_meta.block_list.shape[0], batch_size
)
# No `cu_seqlen_prefill` in decode; the paged attention metadata is passed instead.
self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids),
position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
slots=_async_h2d_tensor_copy(slots_tensor),
seqlen=trim_seqlen_metadata(seqlen),
lm_head_indices=None,
adapter_data=None,
hpu_attention_meta=hpu_attention_meta,
**kwargs,
)
def forward(
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
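"""
Run the model forward for `batch`. When speculative ids are present, the
inputs (ids, positions, slots, input lengths, block tables) are expanded so
that every speculative candidate token gets its own row; otherwise the batch
tensors are used as-is. Inputs are copied to the device asynchronously and the
method returns the logits together with the speculative logits, if any.
"""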
# Model Forward
if batch.speculative_ids is not None:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1
new_input_ids = torch.cat(
[input_ids.unsqueeze(-1), speculative_ids], dim=1
).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32)
new_position_ids = (
position_ids.unsqueeze(-1).expand(B, new_length) + arange
).view(-1)
# Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices,
# then update the slots with the additional indices to ensure we're grabbing the ones that have been
# allocated
slot_indices = (
batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
slots = batch.slots[slot_indices]
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
# Copy the block tables for every speculative position of each batch member
block_tables = (
block_tables.unsqueeze(1)
.expand(B, new_length, -1)
.reshape(B * new_length, -1)
.contiguous()
)
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
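# Illustrative example: with B = 2 requests and speculative_length = 2,
# new_length = 3 and input_ids [a, b] with speculative_ids [[s1, s2], [t1, t2]]
# expand to [a, s1, s2, b, t1, t2]; position ids, slots and input lengths are
# expanded the same way by adding arange(new_length) per request.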
else:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode.
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad
else:
slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[: slots.shape[0]] = slots
slots = slots_pad
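# The padding above keeps `slots` the same length as the (possibly padded)
# input_ids: entries are scattered through prefill_cache_indices when it is
# set, otherwise copied into the leading positions, with the remainder
# defaulting to slot 0. Keeping the length fixed presumably matches the static
# shapes expected by the HPU graph path.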
seqlen = Seqlen(
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
kwargs = {}
if htorch.utils.internal.is_lazy():
batch_size = input_lengths.shape[0]
prompt_len = (
input_ids.shape[0] // batch_size
if batch.prefilling
else batch.hpu_attn_meta.block_list.shape[0]
)
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
batch.prefilling, prompt_len, batch_size
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=kv_cache,
slots=_async_h2d_tensor_copy(slots),
seqlen=trim_seqlen_metadata(seqlen),
lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
# TODO: adapters are not supported yet; pass real adapter data here once they are
adapter_data=None,
hpu_attention_meta=batch.hpu_attn_meta,
**kwargs,
)
return logits, speculative_logits
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batches: List[FlashCausalLMBatch]
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
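"""
Generate one decoding step for `batches`. To overlap CPU work with the device
forward, the work is organized as a pipeline: (1) turn the logits produced by
the previous forward into next-token ids and update each batch in place,
(2) concatenate/pad the batches and launch the next forward, (3) detokenize
and return the generations finalized in stage 1. Returns the generations, the
batch to keep decoding (or None when every request stopped) and the
(forward_ns, decode_ns) timings.
"""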
# In order to pipeline CPU work with the device forward, we perform the operation in 3 main stages:
# Stage 1. Collect next token ids of any previously started generations
start = time.time_ns()
prev_batches = []
requests_to_generate = []
for batch_id, batch in enumerate(batches):
if batch.next_token_logits is not None:
prefill = batch.prefilling
if batch.prefilling:
batch.prefilling = False
batch.prefilling_mask = [False] * len(batch)
speculate = get_speculate()
(
next_input_ids,
next_token_logprobs,
logprobs,
accepted_ids,
speculative_ids,
) = batch.next_token_chooser(
batch.all_input_ids_tensor[
: batch.next_token_logits.shape[0], : batch.max_current_length
],
batch.next_token_logits,
speculate,
batch.speculative_ids,
batch.speculative_logits,
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens,
_async_h2d_tensor_copy(batch.top_n_tokens_tensor),
logprobs,
accepted_ids,
)
if batch.valid_indices is not None:
# TODO: speculative decoding is not handled here yet
index = torch.arange(
0,
len(batch.valid_indices),
device=batch.all_input_ids_tensor.device,
)
batch.all_input_ids_tensor.index_copy_(
0, index, batch.all_input_ids_tensor[batch.valid_indices]
)
padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(
len(batch.valid_indices)
)
next_input_ids.index_copy_(
0, index, next_input_ids[batch.valid_indices]
)
next_input_ids = next_input_ids[:padded_total_bs]
next_token_logprobs.index_copy_(
0, index, next_token_logprobs[batch.valid_indices]
)
accepted_ids.index_copy_(
0, index, accepted_ids[batch.valid_indices]
)
if speculative_ids is not None:
speculative_ids = speculative_ids[batch.valid_indices]
batch.top_n_tokens_tensor = batch.top_n_tokens_tensor[
batch.valid_indices
]
top_n_tokens = []
batch_top_token_ids_v = []
batch_top_token_logprobs_v = []
for i in batch.valid_indices:
top_n_tokens.append(batch.top_n_tokens[i])
batch_top_token_ids_v.append(batch_top_token_ids[i])
batch_top_token_logprobs_v.append(batch_top_token_logprobs[i])
batch_top_token_ids = batch_top_token_ids_v
batch_top_token_logprobs = batch_top_token_logprobs_v
batch.top_n_tokens = top_n_tokens
batch.next_token_chooser = batch.next_token_chooser.filter(
batch.valid_indices
)
batch.valid_indices = None
# Since we are done prefilling, the tensors that concatenated values for all the requests
# collapse to shape [BATCH_SIZE]
if prefill:
indices = batch.cu_seqlen_prefill[1:] - 1
# inputs were left-padded, so gather the real positions through prefill_cache_indices first
if batch.prefill_cache_indices is not None:
batch.position_ids = batch.position_ids[
batch.prefill_cache_indices
][indices]
else:
batch.position_ids = batch.position_ids[indices]
batch.slot_indices = batch.slot_indices[indices[: len(batch)]]
if batch.adapter_meta is not None:
batch.adapter_meta.adapter_indices = (
batch.adapter_meta.adapter_indices[indices]
)
# For each member of the batch
# Cumulative length
if batch.speculative_logits is not None:
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
for i in range(len(batch)):
batch.all_input_ids_tensor[
i,
batch.cache_lengths[i]
+ batch.input_lengths[i] : batch.cache_lengths[i]
+ batch.input_lengths[i]
+ accepted_ids[i],
] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
accepted_ids = accepted_ids.cpu()
if batch.position_ids.dim() == 2:
# Qwen2_vl case:
batch.position_ids += accepted_ids.unsqueeze(-1)
else:
batch.position_ids += accepted_ids
batch.cache_lengths_tensor += (
batch.input_lengths_tensor + accepted_ids - 1
)
batch.input_lengths_tensor = torch.ones_like(
batch.input_lengths_tensor
)
batch.slot_indices += accepted_ids[: len(batch)]
else:
index = batch.cache_lengths_tensor + batch.input_lengths_tensor
index = F.pad(
index, (0, next_input_ids.shape[0] - index.shape[0]), value=0
)
index = index.to(batch.all_input_ids_tensor.device)
batch_idx = torch.arange(
0,
index.shape[0],
dtype=torch.long,
device=batch.all_input_ids_tensor.device,
)
batch.all_input_ids_tensor.index_put_(
(batch_idx, index.long()), next_input_ids
)
batch.input_ids = next_input_ids
batch.position_ids += 1
batch.cache_lengths_tensor += batch.input_lengths_tensor
batch.input_lengths_tensor = torch.ones_like(
batch.input_lengths_tensor
)
batch.slot_indices += 1
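# Illustrative example: a request with cache_length 10 and input_length 3 has
# its next token written at position 13 of all_input_ids_tensor; the cache
# length then becomes 13 and the input length is reset to 1 for the next
# decode step.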
batch.speculative_ids = speculative_ids
# Does an HPU <-> CPU sync internally
if prefill and batch.adapter_meta is not None:
# adjust segment lengths to account for all request lengths being 1 during decoding
adapter_segments, _ = find_segments(
batch.adapter_meta.adapter_indices
)
batch.adapter_meta.adapter_segments = torch.tensor(
adapter_segments,
dtype=torch.int32,
device=batch.adapter_meta.adapter_segments.device,
)
prev_batches.append(
{
"next_token_ids": next_input_ids,
"next_token_logprobs": next_token_logprobs,
"accepted_ids": accepted_ids,
}
)
idx = len(prev_batches) - 1
for req_idx, req in enumerate(batch.requests):
new_input_length = 1
if batch.speculative_logits is not None:
new_cache_length = (
batch.cache_lengths[req_idx]
+ batch.input_lengths[req_idx]
+ accepted_ids[req_idx]
- 1
)
else:
new_cache_length = (
batch.cache_lengths[req_idx] + batch.input_lengths[req_idx]
)
batch.cache_lengths[req_idx] = new_cache_length
batch.max_input_length = max(
batch.max_input_length, new_input_length
)
batch.input_lengths[req_idx] = new_input_length
current_length = new_cache_length + new_input_length
batch.max_current_length = max(
batch.max_current_length, current_length
)
requests_to_generate.append(
{
"idx": idx,
"request_id": req.id,
"prefix_offset": batch.prefix_offsets[req_idx],
"read_offset": batch.read_offsets[req_idx],
"stopping_criteria": batch.stopping_criterias[req_idx],
"all_input_ids": batch.all_input_ids[req_idx],
"do_sample": batch.next_token_chooser.do_sample[req_idx],
"seed": batch.next_token_chooser.seeds[req_idx],
"top_n_tokens": batch.top_n_tokens[req_idx],
"top_token_ids": batch_top_token_ids[req_idx],
"top_token_logprobs": batch_top_token_logprobs[req_idx],
}
)
if prefill:
# We do not need prefill tensors anymore
batch.cu_seqlen_prefill = None
batch.prefill_cache_indices = None
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
batch.next_token_logits = None
batch.speculative_ids = None
htorch.core.mark_step()
# Stage 2. Prepare new batch for speculative scheduling
if len(batches) > 1:
if self.bucketing_ctx is not None:
total_batch_size = 0
for b in batches:
total_batch_size += len(b)
padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(
total_batch_size
)
batch = self.batch_type.concatenate(
batches, padded_total_bs=padded_total_bs
)
else:
batch = self.batch_type.concatenate(batches)
else:
batch = batches[0]
prefill = batch.prefilling
if prefill:
if self.bucketing_ctx is not None:
batch.prepare_for_prefill(
self.bucketing_ctx.get_padded_prompt_seq_len(
batch.max_input_length
),
self.bucketing_ctx.get_padded_prompt_batch_size(len(batch)),
self.max_total_tokens,
self.tokenizer.pad_token_id,
)
else:
batch.prepare_for_prefill(
batch.max_input_length,
len(batch),
self.max_total_tokens,
self.tokenizer.pad_token_id,
)
else:
batch.prepare_for_decode(
self.dtype,
self.use_contiguous_pa,
self.bucketing_ctx,
self.tokenizer.pad_token_id,
)
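# With bucketing enabled, the prompt length and batch size (prefill) or batch
# size and block count (decode) are padded up to the nearest configured
# bucket, so only a bounded set of shapes ever reaches the compiled HPU
# graphs. Illustrative example: 3 prompts of max length 37 might be padded to
# a batch of 4 sequences of length 64, depending on the bucket configuration.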
if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds):
self.set_inputs_embeds(batch)
prefill_logprobs = batch.prefill_next_token_indices is not None
# Update adapter indices for speculative tokens (if present)
adapter_meta = batch.adapter_meta
if adapter_meta is not None:
if batch.speculative_ids is not None:
B, speculative_length = batch.speculative_ids.shape
new_length = speculative_length + 1
adapter_indices = (
adapter_meta.adapter_indices.unsqueeze(-1)
.expand(B, new_length)
.reshape(-1)
)
adapter_segments = adapter_meta.adapter_segments * new_length
adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_meta.adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_meta.segment_indices,
)
# Assign pointers to adapter weights
# TODO(travis): don't update this if indices haven't changed
adapter_data = AdapterBatchData.from_meta(
adapter_meta,
self.layer_to_adapter_weights,
prefill,
batch.prefill_head_indices,
)
else:
adapter_data = None
out, speculative_logits = self.forward(batch, adapter_data)
if prefill:
batch.next_token_logits = (
out[batch.prefill_next_token_indices] if prefill_logprobs else out
)
if speculative_logits is not None:
speculative_logits = (
speculative_logits[batch.prefill_next_token_indices]
if prefill_logprobs
else speculative_logits
)
else:
prefill_logprobs = None
batch.next_token_logits = out
batch.speculative_logits = speculative_logits
# HPU->CPU sync
htorch.core.mark_step()
start_decode = time.time_ns()
for prev_batch in prev_batches:
prev_batch["next_token_logprobs"] = prev_batch[
"next_token_logprobs"
].tolist()
prev_batch["next_token_ids"] = prev_batch["next_token_ids"].tolist()
prev_batch["accepted_ids"] = prev_batch["accepted_ids"].tolist()
htorch.core.mark_step()
# Stage 3. Finish and return previous generations
# Results
generations: List[Generation] = []
stopped = len(requests_to_generate) > 0
# Reset max_input_length
batch.max_input_length = 0
# For each member of the batch
token_indices = [0] * len(prev_batches)
idx_accept_ids = [0] * len(prev_batches)
for i, req_data in enumerate(requests_to_generate):
idx = req_data["idx"]
request_id = req_data["request_id"]
prefix_offset = req_data["prefix_offset"]
read_offset = req_data["read_offset"]
stopping_criteria = req_data["stopping_criteria"]
all_input_ids = req_data["all_input_ids"]
do_sample = req_data["do_sample"]
seed = req_data["seed"]
top_n_tokens = req_data["top_n_tokens"]
n_accepted_ids = prev_batches[idx]["accepted_ids"][idx_accept_ids[idx]]
top_token_ids = req_data["top_token_ids"]
top_token_logprobs = req_data["top_token_logprobs"]
# Append next token to all tokens
next_token_texts = []
left = 0
if n_accepted_ids > 1:
log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")
current_stopped = False
index = token_indices[idx]
for j in range(index, index + n_accepted_ids):
# Generated token
next_token_id = prev_batches[idx]["next_token_ids"][j]
all_input_ids.append(next_token_id)
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids,
prefix_offset,
read_offset,
)
next_token_texts.append(next_token_text)
stop, reason = stopping_criteria(
next_token_id,
next_token_text,
)
if stop:
left = index + n_accepted_ids - j - 1
current_stopped = True
break
else:
current_stopped = False
stopped = stopped and current_stopped
_next_token_ids = prev_batches[idx]["next_token_ids"][
index : index + n_accepted_ids - left
]
_next_token_logprobs = prev_batches[idx]["next_token_logprobs"][
index : index + n_accepted_ids - left
]
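# `left` counts the speculative tokens that came after the stop condition;
# they are dropped here so the returned Tokens only go up to and including the
# token that triggered the stopping criteria.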
# Shard generations
# All generations will be appended in the rust sharded client
if request_id % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text, _, _ = self.decode_token(
all_input_ids,
prefix_offset=len(all_input_ids)
- stopping_criteria.current_tokens
- 1,
read_offset=len(all_input_ids)
- stopping_criteria.current_tokens,
skip_special_tokens=True,
)
generated_text = GeneratedText(
output_text,
stopping_criteria.current_tokens,
reason,
seed if do_sample else None,
)
else:
generated_text = None
if top_n_tokens > 0:
all_top_tokens = []
for top_token_ids, top_token_logprobs in zip(
top_token_ids, top_token_logprobs
):
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids
for token_id in top_token_ids
]
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
all_top_tokens.append(top_tokens)
top_tokens = all_top_tokens
else:
top_tokens = None
generation = Generation(
request_id,
None,
Tokens(
_next_token_ids,
_next_token_logprobs,
next_token_texts,
[nid in self.all_special_ids for nid in _next_token_ids],
),
generated_text,
top_tokens,
)
generations.append(generation)
# accept each new token for this specific request since we may
# have more than one new token per request with speculative decoding
for next_token_id in _next_token_ids:
batch.next_token_chooser = (
batch.next_token_chooser.advance_grammar_single(
i, next_token_id
)
)
# Update values
token_indices[idx] += n_accepted_ids
idx_accept_ids[idx] += 1
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids
htorch.core.mark_step()
if stopped:
# No need to return a batch if we know that all requests stopped
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)