# backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
import torch
import numpy as np
from typing import Iterable, Optional, Tuple, List, Dict
from text_generation_server.pb.generate_pb2 import Request
from io import BytesIO
from PIL import Image
from dataclasses import dataclass, field
from opentelemetry import trace
from transformers import (
PreTrainedTokenizerBase,
)
from text_generation_server.models.flash_causal_lm import (
prepare_for_decode,
)
from text_generation_server.models.flash_vlm_causal_lm import (
FlashVlmCausalLMBatch,
FlashVlmCausalLM,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.layers.attention import (
Seqlen,
trim_seqlen_metadata,
_async_h2d_tensor_copy,
)
import habana_frameworks.torch as htorch
from loguru import logger
from text_generation_server.models.globals import BLOCK_SIZE
from text_generation_server.utils.import_utils import (
synchronize,
)
import torch.nn.functional as F
from text_generation_server.utils.log import log_master
import time
import os
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
tracer = trace.get_tracer(__name__)
@dataclass
class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
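    """Batch state for Mllama (Llama 3.2 Vision) on Gaudi.

    Extends FlashVlmCausalLMBatch with the vision-specific tensors this model
    needs: the aspect-ratio metadata produced by the image processor, the
    vision encoder's cross-attention states, and `image_indices`, which maps
    each image to the request it belongs to.
    """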
    image_indices: List[int] = field(default_factory=list)
aspect_ratio_ids: Optional[torch.Tensor] = None
aspect_ratio_mask: Optional[torch.Tensor] = None
cross_attention_states: Optional[torch.Tensor] = None
def prepare_for_prefill(
self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
):
super(FlashVlmCausalLMBatch, self).prepare_for_prefill(
max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
)
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches, padded_total_bs: int = 0):
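        """Concatenate batches into one, merging their image bookkeeping.

        Cross-attention states are concatenated along their first dimension,
        and each batch's `image_indices` entries are offset by the number of
        entries contributed by the preceding batches.
        """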
batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, padded_total_bs)
batch.pixel_values = None
batch.pixel_attention_mask = None
offset = 0
image_indices = []
attention_states = []
for b in batches:
if b.cross_attention_states is not None:
attention_states.append(b.cross_attention_states)
image_indices.extend([i + offset for i in b.image_indices])
offset += len(b.image_indices)
if len(attention_states) > 0:
assert len(image_indices) > 0
batch.cross_attention_states = torch.cat(attention_states, dim=0)
batch.image_indices = image_indices
else:
batch.cross_attention_states = None
batch.image_indices = []
return batch
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]):
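        """Keep only the requests in `request_ids`.

        Image indices are recomputed for the surviving requests and
        `cross_attention_states` is sliced down to the matching rows.
        """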
assert self.image_indices is not None
batch = super(FlashVlmCausalLMBatch, self).filter(request_ids)
assert self.image_indices is not None
indices = []
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
indices.append(idx)
offset = 0
new_image_indices = []
prev_i = None
for i in self.image_indices:
if i in indices:
new_image_indices.append(offset)
if i != prev_i:
offset += 1
prev_i = i
batch.image_indices = new_image_indices
if len(new_image_indices) > 0:
assert max(new_image_indices) < self.cross_attention_states.shape[0]
assert offset <= self.cross_attention_states.shape[0]
batch.cross_attention_states = self.cross_attention_states[
new_image_indices
]
else:
batch.cross_attention_states = None
batch.pixel_values = None
return batch
@classmethod
def batch_tokenized_inputs(
cls, requests: Iterable[Request], tokenizer, processor, config
):
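        """Tokenize the requests and preprocess their images.

        Text chunks are concatenated per request, with an `<|image|>`
        placeholder inserted for every image chunk, and then tokenized. Each
        request's image (only the last one is kept) is run through the
        processor's image processor; the results are stacked into a single
        `pixel_values` tensor plus aspect-ratio tensors when available, and
        `image_indices` records which request each image came from.
        """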
image_inputs = []
texts = []
image_indices = []
batch_tokenized_inputs = []
for i, r in enumerate(requests):
            # Each request input is a sequence of chunks; every chunk is either a text segment or an image.
curr_text = ""
curr_image = None
curr_i = None
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
curr_text += chunk.text
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
# TODO unsure about BOS
curr_text += "<|image|>"
image_input = processor.image_processor(image, return_tensors="pt")
curr_image = image_input
curr_i = i
# image_inputs.append(image_input)
# image_indices.append(i)
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
texts.append(curr_text)
if curr_image is not None:
image_inputs.append(curr_image)
image_indices.append(curr_i)
input_ids = tokenizer(
curr_text,
truncation=True,
max_length=r.truncate,
add_special_tokens=r.add_special_tokens,
)["input_ids"]
batch_tokenized_inputs.append(input_ids)
if image_inputs:
image_input = image_inputs[0]
new_image_inputs = {
"pixel_values": torch.cat(
[img["pixel_values"] for img in image_inputs], dim=0
),
}
if "aspect_ratio_ids" in image_input:
new_image_inputs["aspect_ratio_ids"] = torch.cat(
[img["aspect_ratio_ids"] for img in image_inputs], dim=0
)
if "aspect_ratio_mask" in image_input:
new_image_inputs["aspect_ratio_mask"] = torch.cat(
[img["aspect_ratio_mask"] for img in image_inputs], dim=0
)
image_inputs = new_image_inputs
image_inputs["image_indices"] = image_indices
else:
image_inputs = None
if image_inputs is not None:
assert len(image_indices) == image_inputs["pixel_values"].shape[0]
return batch_tokenized_inputs, image_inputs
@classmethod
def from_pb_processor(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
processor,
config,
dtype: torch.dtype,
device: torch.device,
) -> "FlashVlmCausalLMBatch":
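        """Build a batch from a protobuf `Batch`, including image tensors.

        Token ids are clamped to the text vocabulary size because the
        `<|image|>` placeholder id lies outside it, and the preprocessed image
        tensors are moved to the target device and dtype.
        """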
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
pb.requests, tokenizer, processor, config
)
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
# XXX: <|image|> token is actually out of bounds and bugs out the logit processors.
batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
max=config.text_config.vocab_size - 1
)
if isinstance(batch.input_ids, list):
if len(batch) > 1:
input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
else:
input_ids = batch.input_ids[0]
batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(
device=device, dtype=dtype
)
batch.aspect_ratio_ids = image_inputs["aspect_ratio_ids"].to(device=device)
batch.aspect_ratio_mask = image_inputs["aspect_ratio_mask"].to(
device=device
)
batch.image_indices = image_inputs["image_indices"]
else:
batch.pixel_values = None
batch.aspect_ratio_ids = None
batch.aspect_ratio_mask = None
batch.image_indices = []
assert batch.image_indices is not None
return batch
def generate_cross_attention_states(
cross_attention_states, image_indices, input_lengths, pad_seq_len, prefilling
):
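    """Compute the token positions that consume the cross-attention states.

    During prefill each image is associated with the padded token range of its
    request (`pad_seq_len` positions per request); during decode there is a
    single position per request, so the image indices are used directly.
    Returns `(indices, cross_attention_len)`, or `(None, None)` when there are
    no cross-attention states.
    """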
if cross_attention_states is None:
        return None, None
indices_list = []
if prefilling:
for i in image_indices:
indices_list.append(torch.arange(pad_seq_len * i, pad_seq_len * (i + 1)))
indices = torch.cat(indices_list, dim=0)
else:
indices = image_indices[:]
return indices, input_lengths.index_select(0, image_indices)
class FlashMllamaCausalLM(FlashVlmCausalLM):
def set_inputs_embeds(self, batch):
# Set the input embeddings to None, as we are using the input_ids for the model
batch.inputs_embeds = None
def warmup_decode(
self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
):
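        """Run a dummy decode step of shape (batch_size, block_num).

        A synthetic batch is built whose KV-cache usage is spread over
        `block_num` blocks, and one forward pass is executed so the matching
        HPU graph gets captured and warmed.
        """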
input_ids = torch.zeros(batch_size, dtype=batch.input_ids.dtype)
position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype)
blocks = [block_num // batch_size for _ in range(batch_size)]
blocks[0] += block_num % batch_size
past_len = []
block_tables = []
slots = []
start_idx = 0
        # Use the last slot of each request's final block so the warmup touches `block_num` blocks in total
for i in range(batch_size):
block_array = list(range(start_idx, start_idx + blocks[i]))
slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
block_tables.append(block_array)
past_len.append(blocks[i] * BLOCK_SIZE - 1)
start_idx += blocks[i]
input_lengths = torch.ones(batch_size, dtype=torch.int32)
seqlen = Seqlen(
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
hpu_attention_meta = prepare_for_decode(
self.dtype,
self.use_contiguous_pa,
self.device,
slots,
block_tables,
batch_size,
bucketing_ctx=None,
)
        # Repeat the warmup batch's image data so every request in the dummy decode batch has cross-attention states.
image_indices = torch.tensor(batch.image_indices)
image_indices = image_indices.repeat(batch_size)
cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)
indices, cross_attention_len = generate_cross_attention_states(
cross_attention_states, image_indices, input_lengths, 1, False
)
slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
False, hpu_attention_meta.block_list.shape[0], batch_size
)
self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids),
position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
slots=_async_h2d_tensor_copy(slots_tensor),
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=hpu_attention_meta,
lm_head_indices=None,
adapter_data=None,
cross_attention_states=cross_attention_states,
indices=_async_h2d_tensor_copy(indices),
cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),
**kwargs,
)
def warmup_prefill(
self, prompt_len: int, batch_size: int, batch: FlashMllamaCausalLMBatch
):
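        """Run a dummy prefill of shape (batch_size, prompt_len).

        Inputs, positions, slots and cross-attention states are replicated to
        `batch_size` requests of `prompt_len` tokens each, and one forward
        pass is executed to capture and warm the matching HPU graph.
        """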
input_ids = torch.zeros(prompt_len, dtype=batch.input_ids.dtype).repeat(
batch_size
)
position_ids = torch.arange(prompt_len, dtype=batch.position_ids.dtype).repeat(
batch_size
)
max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size
block_tables = torch.arange(max_bt, dtype=torch.int32).reshape(batch_size, -1)
slot_acc = []
for i in range(batch_size):
slots = []
for b in block_tables[i]:
slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
slot_acc.extend(slots[:prompt_len])
slots = torch.tensor(slot_acc, dtype=batch.slots.dtype)
input_lengths = (
torch.ones(
batch_size,
dtype=torch.int32,
)
* prompt_len
)
cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32)
torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
lm_head_indices = input_lengths - 1
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
image_indices = torch.tensor(batch.image_indices)
image_indices = image_indices.repeat(batch_size)
cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)
indices, cross_attention_len = generate_cross_attention_states(
cross_attention_states, image_indices, input_lengths, prompt_len, True
)
seqlen = Seqlen(
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
True, prompt_len, batch_size
)
self.model.forward(
input_ids=_async_h2d_tensor_copy(input_ids),
position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=self.kv_cache,
slots=_async_h2d_tensor_copy(slots),
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=None,
lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
adapter_data=None,
cross_attention_states=cross_attention_states,
indices=_async_h2d_tensor_copy(indices),
cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),
**kwargs,
)
def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
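        """Warm up HPU graphs for the bucketed prefill and decode shapes.

        Free device memory (minus the reserved amount) is split between prompt
        and decode graphs according to `VLLM_GRAPH_PROMPT_RATIO`. Buckets are
        then warmed in order, and a bucket is only added to `graphed_buckets`
        while its estimated graph memory still fits the remaining budget.
        """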
prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
free_mem = HabanaMemoryProfiler.current_free_device_memory()
graph_free_mem = free_mem - self.mem_reserved
graph_free_mem = self.align_workers(
graph_free_mem, torch.distributed.ReduceOp.MIN
)
prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
decode_available_memory = graph_free_mem - prompt_available_memory
msg = (
f"Using {format_bytes(graph_free_mem)}"
f"/{format_bytes(free_mem)} "
"of free device memory for HPUGraphs, "
f"{format_bytes(prompt_available_memory)} for prompt and "
f"{format_bytes(decode_available_memory)} for decode "
f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
)
log_master(logger.info, msg)
start_time = time.time()
warmup_shape_count = 0
warmup_times = 3
self.bucketing_ctx.generate_prompt_buckets()
def ordering_function_min_tokens(b):
return (b[0] * b[1], b[1], b[0])
buckets = list(
sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
)
total_batch_seq = 0.001
total_mem = 0
available_mem = prompt_available_memory
msg = (
f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n"
f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n"
)
log_master(logger.info, msg)
for i, (batch_size, seq_len) in enumerate(buckets):
if batch_size * seq_len > self.max_batch_prefill_tokens:
continue
            # Graph memory usage is roughly proportional to the total token count (batch_size * seq_len) of the bucket
batch_seq = batch_size * seq_len
mem_estimate = batch_seq / total_batch_seq * total_mem
graphed_bucket = (batch_size, seq_len, True)
if not (
mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
):
if graphed_bucket not in self.graphed_buckets:
self.graphed_buckets.add(graphed_bucket)
warmup_shape_count += 1
self.log_warmup(True, i, len(buckets), batch_size, seq_len)
with HabanaMemoryProfiler() as mem_prof:
for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size, batch)
synchronize(self.device)
used_mem = self.align_workers(
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
)
if graphed_bucket in self.graphed_buckets:
available_mem -= used_mem
total_mem += used_mem
total_batch_seq += batch_seq
log_master(logger.info, "Prefill warmup successful.\n")
def ordering_function_max_bs(b):
return (-b[0], b[1])
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
buckets = list(
sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
)
free_mem = HabanaMemoryProfiler.current_free_device_memory()
total_batch_seq = 0.001
total_mem = 0
available_mem = free_mem - self.mem_reserved
log_master(
logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n"
)
for i, (batch_size, block_num) in enumerate(buckets):
if batch_size > block_num:
continue
            # For decode, graph memory usage scales with the batch size of the bucket
batch_seq = batch_size
mem_estimate = batch_seq / total_batch_seq * total_mem
graphed_bucket = (batch_size, block_num, False)
if not mem_estimate >= available_mem:
if graphed_bucket not in self.graphed_buckets:
self.graphed_buckets.add(graphed_bucket)
warmup_shape_count += 1
self.log_warmup(False, i, len(buckets), batch_size, block_num)
with HabanaMemoryProfiler() as mem_prof:
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device)
used_mem = self.align_workers(
mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
)
if graphed_bucket in self.graphed_buckets:
available_mem -= used_mem
total_mem += used_mem
total_batch_seq += batch_seq
log_master(logger.info, "Decode warmup successful.\n")
log_master(
logger.info,
f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
)
def forward(
self,
batch: FlashMllamaCausalLMBatch,
adapter_data: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
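        """Run one model forward pass for `batch`.

        Handles speculative-decoding expansion, runs the vision tower when
        pixel values are present, pads slots and cross-attention inputs to the
        padded batch shape, and returns `(logits, speculative_logits)`.
        """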
# Model Forward
if batch.speculative_ids is not None:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1
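            # Expand every request to (1 + speculative_length) rows: interleave the
            # verified token with its speculative ids, then offset positions, slots
            # and input lengths accordingly.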
new_input_ids = torch.cat(
[input_ids.unsqueeze(-1), speculative_ids], dim=1
).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32)
new_position_ids = (
position_ids.unsqueeze(-1).expand(B, new_length) + arange
).view(-1)
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
            # Copy the block tables for every speculative position
block_tables = (
block_tables.unsqueeze(1)
.expand(B, new_length, -1)
.reshape(B * new_length, -1)
.contiguous()
)
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
else:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode.
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
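        # The vision tower only runs when pixel values are present (image prefill);
        # its cross-attention states are cached on the batch for later decode steps.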
if batch.pixel_values is not None:
cross_attention_states = self.model.vision_forward(
pixel_values=batch.pixel_values,
aspect_ratio_ids=batch.aspect_ratio_ids,
aspect_ratio_mask=batch.aspect_ratio_mask,
)
batch.cross_attention_states = cross_attention_states
cross_attention_states = batch.cross_attention_states
kwargs = {}
if htorch.utils.internal.is_lazy():
batch_size = input_lengths.shape[0]
seqlen = (
input_ids.shape[0] // batch_size
if batch.prefilling
else batch.hpu_attn_meta.block_list.shape[0]
)
kwargs["bypass_hpu_graphs"] = not self.use_graphs(
batch.prefilling, seqlen, batch_size
)
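        # Scatter the slots into a zero tensor shaped like input_ids so the slot
        # tensor matches the padded input_ids length.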
if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad
else:
slots_pad = torch.zeros_like(input_ids, device=slots.device)
slots_pad[: slots.shape[0]] = slots
slots = slots_pad
orig_bs = len(batch)
padded_bs = batch.input_lengths_tensor.shape[0]
padded_input_len = input_ids.view(padded_bs, -1).shape[-1]
image_indices = torch.tensor(batch.image_indices)
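        # Pad the cross-attention inputs to the padded batch size: append zero
        # rows to the states and dummy image indices for the padding requests.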
if cross_attention_states is not None:
cross_attention_states = F.pad(
cross_attention_states,
(0, 0, 0, 0, 0, (padded_bs - orig_bs)),
value=0,
)
if len(image_indices) != 0:
pad_indices = torch.arange(orig_bs, padded_bs)
image_indices = torch.cat((image_indices, pad_indices), dim=0)
indices, cross_attention_len = generate_cross_attention_states(
cross_attention_states,
image_indices,
input_lengths,
padded_input_len,
batch.prefilling,
)
seqlen = Seqlen(
input_lengths=_async_h2d_tensor_copy(input_lengths),
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=_async_h2d_tensor_copy(position_ids),
cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
kv_cache=kv_cache,
slots=_async_h2d_tensor_copy(slots),
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=batch.hpu_attn_meta,
lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
# TODO list
adapter_data=None,
cross_attention_states=cross_attention_states,
indices=_async_h2d_tensor_copy(indices),
cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),
**kwargs,
)
if batch.pixel_values is not None:
batch.pixel_values = None
return logits, speculative_logits