def capture_model()

in arctic_inference/vllm/model_runner.py [0:0]


    def capture_model(self) -> None:
        """Capture CUDA graphs for the original model and, if set, the shift model.

        If ``self.use_cuda_graph`` is false, only a warning is logged.
        Otherwise, inside ``parallel_state.graph_capture``:

        * The original model is run for every size ``s`` in
          ``self.cudagraph_batch_sizes`` (largest first) satisfying
          ``shift_parallel_threshold < s * sp_size <= max_num_tokens``, where
          ``sp_size`` is ``parallel_config.ulysses_sequence_parallel_size``;
          ``_dummy_run`` receives ``s * sp_size`` tokens.
        * If ``self.shift_model`` is not None, it is temporarily installed as
          ``self.model`` and run under ``set_shift_parallel_mode(True)`` for
          every size ``s <= shift_parallel_threshold`` (every size when the
          shift model's class name contains "SwiftKV"); ``_dummy_run``
          receives ``s`` tokens. ``self.model`` is restored afterwards, also
          when an exception is raised.

        Each shape gets ``cudagraph_num_of_warmups`` warm-up ``_dummy_run``
        calls plus one more call (presumably the one that records the graph —
        confirm against ``_dummy_run``). All calls pass ``skip_eplb=True``.
        Elapsed time and the drop in free GPU memory are logged at the end.
        """
        if not self.use_cuda_graph:
            logger.warning(
                "Skipping CUDA graph capture. To turn on CUDA graph capture, "
                "set -O %s and ensure `use_cudagraph` was not manually set to "
                "False", CompilationLevel.PIECEWISE)
            return

        compilation_counter.num_gpu_runner_capture_triggers += 1

        start_time = time.perf_counter()
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        full_cg = self.full_cuda_graph
        num_warmups = (
            self.vllm_config.compilation_config.cudagraph_num_of_warmups)

        def capture_shapes(shapes, token_scale, label):
            """Warm up and run ``_dummy_run`` for each shape in ``shapes``.

            ``shapes * token_scale`` is the token count passed to
            ``_dummy_run``. Only the global first rank logs the shape list
            and shows a progress bar.
            """
            if is_global_first_rank():
                logger.info("%s model shapes %s", label, shapes)
                shapes = tqdm(
                    shapes,
                    desc=f"Capturing CUDA graph shapes of {label} model")
            for num_tokens in shapes:
                run_tokens = num_tokens * token_scale
                # We skip EPLB here since we don't want to record dummy metrics
                for _ in range(num_warmups):
                    self._dummy_run(run_tokens,
                                    capture_attn_cudagraph=full_cg,
                                    skip_eplb=True)
                self._dummy_run(run_tokens,
                                capture_attn_cudagraph=full_cg,
                                skip_eplb=True)

        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with parallel_state.graph_capture(device=self.device):
            sp_size = self.parallel_config.ulysses_sequence_parallel_size

            # Capture original model shapes.
            orig_shapes = [
                shape for shape in reversed(self.cudagraph_batch_sizes)
                if self.shift_parallel_threshold < shape * sp_size
                <= self.max_num_tokens
            ]
            capture_shapes(orig_shapes, sp_size, "original")

            # Capture shift model shapes.
            if self.shift_model is not None:
                # Note: We want to capture all shapes for the SwiftKV shift
                # model. This is necessary since SwiftKV always uses full TP
                # for the decode runner. For all other models, we only capture
                # necessary shapes for the SP_TP mode, yielding less setup time.
                capture_all = "SwiftKV" in self.shift_model.__class__.__name__
                shift_shapes = [
                    shape for shape in reversed(self.cudagraph_batch_sizes)
                    if capture_all or shape <= self.shift_parallel_threshold
                ]
                orig_model, self.model = self.model, self.shift_model
                try:
                    with set_shift_parallel_mode(True):
                        capture_shapes(shift_shapes, 1, "shift")
                finally:
                    # Restore even on failure so the runner is never left
                    # pointing at the shift model.
                    self.model = orig_model

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
        elapsed_time = end_time - start_time
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes 5~20 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / (1 << 30))