def warmup_hpu_graph()

in backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py [0:0]


    def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
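        """Warm up HPU graphs for prompt (prefill) and decode shapes.

        Free device memory (minus the reserved margin) is budgeted between
        prompt and decode graphs via VLLM_GRAPH_PROMPT_RATIO; a bucket is
        captured as an HPU graph only while its estimated memory cost fits
        the corresponding budget, although every feasible shape is warmed up.
        """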
        prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
        free_mem = HabanaMemoryProfiler.current_free_device_memory()
        graph_free_mem = free_mem - self.mem_reserved
        graph_free_mem = self.align_workers(
            graph_free_mem, torch.distributed.ReduceOp.MIN
        )
        prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
        decode_available_memory = graph_free_mem - prompt_available_memory
        msg = (
            f"Using {format_bytes(graph_free_mem)}"
            f"/{format_bytes(free_mem)} "
            "of free device memory for HPUGraphs, "
            f"{format_bytes(prompt_available_memory)} for prompt and "
            f"{format_bytes(decode_available_memory)} for decode "
            f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
        )
        log_master(logger.info, msg)
        start_time = time.time()
        warmup_shape_count = 0
        warmup_times = 3
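        # Generate prompt buckets and warm them up from the smallest to the
        # largest total token count (batch_size * seq_len), so the running
        # memory estimate is built up from the cheapest shapes first.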
        self.bucketing_ctx.generate_prompt_buckets()

        def ordering_function_min_tokens(b):
            return (b[0] * b[1], b[1], b[0])

        buckets = list(
            sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
        )
        total_batch_seq = 0.001
        total_mem = 0
        available_mem = prompt_available_memory
        msg = (
            f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n"
            f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n"
        )
        log_master(logger.info, msg)
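        # A prefill bucket is captured as an HPU graph only if its extrapolated
        # memory cost fits the prompt budget and its token count stays within
        # max_seq_len_to_capture; shapes above max_batch_prefill_tokens are skipped.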
        for i, (batch_size, seq_len) in enumerate(buckets):
            if batch_size * seq_len > self.max_batch_prefill_tokens:
                continue
            # Graph memory usage is proportional to seq dimension in a batch
            batch_seq = batch_size * seq_len
            mem_estimate = batch_seq / total_batch_seq * total_mem
            graphed_bucket = (batch_size, seq_len, True)
            if not (
                mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
            ):
                if graphed_bucket not in self.graphed_buckets:
                    self.graphed_buckets.add(graphed_bucket)
            warmup_shape_count += 1
            self.log_warmup(True, i, len(buckets), batch_size, seq_len)
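            # Measure the device memory consumed while warming this shape; taking
            # the MAX across workers keeps the shared estimate conservative.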
            with HabanaMemoryProfiler() as mem_prof:
                for index in range(warmup_times):
                    self.warmup_prefill(seq_len, batch_size, batch)
                    synchronize(self.device)
            used_mem = self.align_workers(
                mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
            )
            if graphed_bucket in self.graphed_buckets:
                available_mem -= used_mem
                total_mem += used_mem
                total_batch_seq += batch_seq

        log_master(logger.info, "Prefill warmup successful.\n")

        def ordering_function_max_bs(b):
            return (-b[0], b[1])

        self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
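        # Warm decode buckets with the largest batch size first; whatever free
        # memory remains after prefill warmup (minus the reserved margin)
        # becomes the decode graph budget.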
        buckets = list(
            sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
        )
        free_mem = HabanaMemoryProfiler.current_free_device_memory()
        total_batch_seq = 0.001
        total_mem = 0
        available_mem = free_mem - self.mem_reserved
        log_master(
            logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n"
        )
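        # Skip decode shapes whose batch size exceeds the available block count
        # (each sequence needs at least one KV-cache block).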
        for i, (batch_size, block_num) in enumerate(buckets):
            if batch_size > block_num:
                continue
            # For decode, graph memory usage is proportional to the batch size
            batch_seq = batch_size
            mem_estimate = batch_seq / total_batch_seq * total_mem
            graphed_bucket = (batch_size, block_num, False)
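            # Same capture criterion as prefill, but decode has no explicit
            # sequence-length cap: capture while the estimate fits the budget.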
            if not (mem_estimate >= available_mem):
                if graphed_bucket not in self.graphed_buckets:
                    self.graphed_buckets.add(graphed_bucket)
            warmup_shape_count += 1
            self.log_warmup(False, i, len(buckets), batch_size, block_num)
            with HabanaMemoryProfiler() as mem_prof:
                for index in range(warmup_times):
                    self.warmup_decode(batch_size, block_num, batch)
                    synchronize(self.device)
            used_mem = self.align_workers(
                mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
            )
            if graphed_bucket in self.graphed_buckets:
                available_mem -= used_mem
                total_mem += used_mem
                total_batch_seq += batch_seq

        log_master(logger.info, "Decode warmup successful.\n")

        log_master(
            logger.info,
            f"HPU graph warmup took {int(time.time() - start_time)}s for {warmup_shape_count} shapes",
        )