# backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
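# Warms up HPU graph capture for prefill and decode shape buckets, splitting
# the free-memory budget between the two phases via VLLM_GRAPH_PROMPT_RATIO.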
def warmup_hpu_graph(self, batch):
    prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
    free_mem = HabanaMemoryProfiler.current_free_device_memory()
    graph_free_mem = free_mem - self.mem_reserved
    graph_free_mem = self.align_workers(
        graph_free_mem, torch.distributed.ReduceOp.MIN
    )
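    # Split the graph-capture budget between prompt (prefill) and decode.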
    prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
    decode_available_memory = graph_free_mem - prompt_available_memory
    msg = (
        f"Using {format_bytes(graph_free_mem)}"
        f"/{format_bytes(free_mem)} "
        "of free device memory for HPUGraphs, "
        f"{format_bytes(prompt_available_memory)} for prompt and "
        f"{format_bytes(decode_available_memory)} for decode "
        f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
    )
    log_master(logger.info, msg)
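    # Each bucket shape is replayed several times so graph capture and
    # compilation costs are paid here rather than on the first request.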
    start_time = time.time()
    warmup_shape_count = 0
    warmup_times = 3
    self.bucketing_ctx.generate_prompt_buckets()
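    # Warm up the smallest prompt shapes first: sort by total tokens
    # (batch_size * seq_len), then by seq_len, then by batch_size.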
    def ordering_function_min_tokens(b):
        return (b[0] * b[1], b[1], b[0])

    buckets = list(
        sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
    )
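    # Running totals used to extrapolate per-shape memory; total_batch_seq is
    # seeded with a tiny value so the first estimate avoids division by zero.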
    total_batch_seq = 0.001
    total_mem = 0
    available_mem = prompt_available_memory
    msg = (
        f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n"
        f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n"
    )
    log_master(logger.info, msg)
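    # Estimate-then-measure: project each bucket's memory from what previous
    # buckets consumed, and only capture a graph when the projection fits.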
    for i, (batch_size, seq_len) in enumerate(buckets):
        if batch_size * seq_len > self.max_batch_prefill_tokens:
            continue
        # Graph memory usage is proportional to seq dimension in a batch
        batch_seq = batch_size * seq_len
        mem_estimate = batch_seq / total_batch_seq * total_mem
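        # Mark this bucket for graph capture only if the projected memory
        # fits the budget and the sequence is within the capture limit.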
        graphed_bucket = (batch_size, seq_len, True)
        if not (
            mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
        ):
            if graphed_bucket not in self.graphed_buckets:
                self.graphed_buckets.add(graphed_bucket)
                warmup_shape_count += 1
        self.log_warmup(True, i, len(buckets), batch_size, seq_len)
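        # Measure how much device memory warming this shape actually consumes.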
        with HabanaMemoryProfiler() as mem_prof:
            for index in range(warmup_times):
                self.warmup_prefill(seq_len, batch_size, batch)
            synchronize(self.device)
        used_mem = self.align_workers(
            mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
        )
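        # Only graphed buckets hold on to memory, so only they reduce the
        # budget and feed the running per-token estimate.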
        if graphed_bucket in self.graphed_buckets:
            available_mem -= used_mem
            total_mem += used_mem
            total_batch_seq += batch_seq

    log_master(logger.info, "Prefill warmup successful.\n")
    def ordering_function_max_bs(b):
        return (-b[0], b[1])

    self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
    buckets = list(
        sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
    )
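    # Decode gets whatever memory is still free after prefill warmup, minus
    # the reserved margin; the running totals are reset for a fresh estimate.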
    free_mem = HabanaMemoryProfiler.current_free_device_memory()
    total_batch_seq = 0.001
    total_mem = 0
    available_mem = free_mem - self.mem_reserved
    log_master(
        logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n"
    )
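    # Same estimate-then-measure loop as prefill. Buckets where batch_size
    # exceeds the block count are skipped, since each decoding sequence needs
    # at least one KV-cache block.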
    for i, (batch_size, block_num) in enumerate(buckets):
        if batch_size > block_num:
            continue
        # For decode, graph memory usage scales with the batch dimension
        batch_seq = batch_size
        mem_estimate = batch_seq / total_batch_seq * total_mem
        graphed_bucket = (batch_size, block_num, False)
        if not mem_estimate >= available_mem:
            if graphed_bucket not in self.graphed_buckets:
                self.graphed_buckets.add(graphed_bucket)
                warmup_shape_count += 1
        self.log_warmup(False, i, len(buckets), batch_size, block_num)
        with HabanaMemoryProfiler() as mem_prof:
            for index in range(warmup_times):
                self.warmup_decode(batch_size, block_num, batch)
            synchronize(self.device)
        used_mem = self.align_workers(
            mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
        )
        if graphed_bucket in self.graphed_buckets:
            available_mem -= used_mem
            total_mem += used_mem
            total_batch_seq += batch_seq

    log_master(logger.info, "Decode warmup successful.\n")
    log_master(
        logger.info,
        f"HPU graph warmup took {int(time.time() - start_time)}s, "
        f"warmed up {warmup_shape_count} shapes",
    )