in chatlearn/models/vllm_module.py [0:0]
def execute_step(self, data):
    """Run a single model step on the underlying vLLM worker, dispatching on the installed vLLM version."""
    if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
        output = self.worker.execute_model(
            data["seq_group_metadata_list"],
            data["blocks_to_swap_in"],
            data["blocks_to_swap_out"],
            data["blocks_to_copy"]
        )
    elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]:
        if CURRENT_VLLM_VERSION == VLLMVersion.v_0_5_1:
            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=data["seq_group_metadata_list"],
                blocks_to_swap_in=data["blocks_to_swap_in"],
                blocks_to_swap_out=data["blocks_to_swap_out"],
                blocks_to_copy=data["blocks_to_copy"],
                num_lookahead_slots=data["num_lookahead_slots"],
                running_queue_size=data["running_queue_size"],
                finished_requests_ids=data["finished_requests_ids"]
            )
            output = self.worker.execute_model(execute_model_req=execute_model_req)
        else:
            if len(data) > 0:
                # For llm_engine, there is no pipeline parallel support, so the
                # virtual engine used is always 0.
                virtual_engine = 0
                # The sequence group metadata is passed in via `data` rather than
                # read from cached outputs of a previous iteration.
                seq_group_metadata_list = data["seq_group_metadata_list"]
                # Async output processing is disabled for this step.
                allow_async_output_proc = False
                assert seq_group_metadata_list is not None
                finished_requests_ids = data["finished_requests_ids"]
                # For supporting PP, passing the sampled_token_ids through the request
                # is preferable to a separate broadcast over all the PP stages, which
                # would cause one virtual engine's microbatch to block the pipeline.
                # No cached output is carried over here, so this stays None.
                last_sampled_token_ids = None
                execute_model_req = ExecuteModelRequest(
                    seq_group_metadata_list=seq_group_metadata_list,
                    blocks_to_swap_in=data["blocks_to_swap_in"],
                    blocks_to_swap_out=data["blocks_to_swap_out"],
                    blocks_to_copy=data["blocks_to_copy"],
                    num_lookahead_slots=data["num_lookahead_slots"],
                    running_queue_size=data["running_queue_size"],
                    finished_requests_ids=finished_requests_ids,
                    # We use ExecuteModelRequest to pass the last sampled_token_ids
                    # to each of the non-last PP stages for in-place prepare_input.
                    last_sampled_token_ids=last_sampled_token_ids)
                if allow_async_output_proc:
                    execute_model_req.async_callback = self.async_callbacks[
                        virtual_engine]
                output = self.worker.execute_model(execute_model_req=execute_model_req)
            else:
                # No outputs in this case
                output = []
    else:
        raise RuntimeError(f"Unsupported vllm version {CURRENT_VLLM_VERSION}, expect one of {list(VLLMVersion)}")
    if self.is_last_rank() and hasattr(self, "scheduler_outputs"):
        return self.process_model_outputs(output, seq_group_metadata_list=data["seq_group_metadata_list"])
    return output
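
For reference, the sketch below shows the shape of the `data` payload consumed by the v0.5.1 / v0.6.3 branches above. It is a minimal illustration, not ChatLearn code: the helper name and the empty block-management lists are assumptions, and in practice `seq_group_metadata_list` and `finished_requests_ids` come from the vLLM scheduler on the driver before `execute_step` is invoked on each worker.

# Hypothetical helper; key names mirror exactly what execute_step() reads above.
def build_execute_step_payload(seq_group_metadata_list, finished_requests_ids):
    return {
        "seq_group_metadata_list": seq_group_metadata_list,
        "blocks_to_swap_in": [],      # assume no KV blocks swapped in this step
        "blocks_to_swap_out": [],     # assume no KV blocks swapped out this step
        "blocks_to_copy": [],         # assume no copy-on-write block pairs
        "num_lookahead_slots": 0,     # nonzero only with speculative decoding
        "running_queue_size": len(seq_group_metadata_list),
        "finished_requests_ids": finished_requests_ids,
    }

# Usage (hypothetical): `vllm_module` would be an initialized replica of this class.
#     payload = build_execute_step_payload(seq_group_metadata_list, finished_requests_ids)
#     output = vllm_module.execute_step(payload)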