in arctic_inference/vllm/model_runner.py [0:0]
def capture_model(self) -> None:
    """Capture CUDA graphs for the original model and, if set, the shift model.

    Returns early with a warning when ``self.use_cuda_graph`` is false.
    Otherwise, inside ``parallel_state.graph_capture``:

    * The original model is run on every size in
      ``self.cudagraph_batch_sizes`` (largest first) whose value multiplied
      by the Ulysses sequence-parallel size lies in the half-open range
      ``(self.shift_parallel_threshold, self.max_num_tokens]``.
    * When ``self.shift_model`` is not None, it is temporarily installed as
      ``self.model`` and run under ``set_shift_parallel_mode(True)`` on the
      sizes at or below ``self.shift_parallel_threshold`` (or on every size
      when the shift model's class name contains ``"SwiftKV"``).
      ``self.model`` is restored afterwards, even if a run raises.

    Each size gets ``cudagraph_num_of_warmups`` dummy runs followed by one
    final dummy run. The elapsed time and the drop in free GPU memory are
    logged at the end.
    """
    if not self.use_cuda_graph:
        logger.warning(
            "Skipping CUDA graph capture. To turn on CUDA graph capture, "
            "set -O %s and ensure `use_cudagraph` was not manually set to "
            "False", CompilationLevel.PIECEWISE)
        return

    compilation_counter.num_gpu_runner_capture_triggers += 1

    start_time = time.perf_counter()
    start_free_gpu_memory = torch.cuda.mem_get_info()[0]

    # Trigger CUDA graph capture for specific shapes.
    # Capture the large shapes first so that the smaller shapes
    # can reuse the memory pool allocated for the large shapes.
    with parallel_state.graph_capture(device=self.device):
        sp_size = self.parallel_config.ulysses_sequence_parallel_size
        full_cg = self.full_cuda_graph
        threshold = self.shift_parallel_threshold

        def with_progress(shapes, label):
            """Log the shapes and wrap them in a progress bar on rank 0."""
            # Only rank 0 should print progress bar during capture
            if not is_global_first_rank():
                return shapes
            logger.info("%s model shapes %s", label, shapes)
            return tqdm(
                shapes,
                desc=f"Capturing CUDA graph shapes of {label} model")

        def warmup_and_capture(num_tokens):
            """Run the configured warm-ups, then one final dummy run."""
            # We skip EPLB here since we don't want to record dummy metrics
            for _ in range(self.vllm_config.compilation_config.
                           cudagraph_num_of_warmups):
                self._dummy_run(num_tokens,
                                capture_attn_cudagraph=full_cg,
                                skip_eplb=True)
            self._dummy_run(num_tokens,
                            capture_attn_cudagraph=full_cg,
                            skip_eplb=True)

        # Capture original model shapes. Sizes are scaled by the
        # sequence-parallel size before being compared and run.
        original_shapes = [
            shape for shape in reversed(self.cudagraph_batch_sizes)
            if threshold < shape * sp_size <= self.max_num_tokens
        ]
        for num_tokens in with_progress(original_shapes, "original"):
            warmup_and_capture(num_tokens * sp_size)

        # Capture shift model shapes
        if self.shift_model is not None:
            orig_model, self.model = self.model, self.shift_model
            try:
                # Note: We want to capture all shapes for the SwiftKV shift
                # model. This is necessary since SwiftKV always uses full TP
                # for the decode runner. For all other models, we only
                # capture necessary shapes for the SP_TP mode, yielding less
                # setup time.
                capture_all = "SwiftKV" in self.model.__class__.__name__
                shift_shapes = [
                    shape for shape in reversed(self.cudagraph_batch_sizes)
                    if shape <= threshold or capture_all
                ]
                # Build the progress wrapper before entering shift-parallel
                # mode, matching the original ordering of the rank check.
                shift_cases = with_progress(shift_shapes, "shift")
                with set_shift_parallel_mode(True):
                    for num_tokens in shift_cases:
                        warmup_and_capture(num_tokens)
            finally:
                # Always put the original model back, even if a dummy run
                # raised, so the runner is not left pointing at the shift
                # model.
                self.model = orig_model

    end_time = time.perf_counter()
    end_free_gpu_memory = torch.cuda.mem_get_info()[0]
    elapsed_time = end_time - start_time
    cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
    # This usually takes 5~20 seconds.
    logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                elapsed_time, cuda_graph_size / (1 << 30))