in chatlearn/models/vllm_module.py [0:0]
def process_model_outputs(self, output, seq_group_metadata_list=None):
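    """Turn raw engine ``output`` into request outputs for the supported vLLM
    versions, collect finished requests into ``self.outputs``, and return the
    number of requests still pending."""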
    if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
        step_outputs = self._process_model_outputs(output, self.scheduler_outputs)
    elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]:
        if CURRENT_VLLM_VERSION == VLLMVersion.v_0_5_1:
            step_outputs = self._process_model_outputs(
                output, self.scheduler_outputs.scheduled_seq_groups,
                self.scheduler_outputs.ignored_seq_groups, self.seq_group_metadata_list)
        else:
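            # vLLM v0.6.3 path: mirrors the output-processing stage of
            # LLMEngine.step() for a single, synchronous scheduler iteration.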
            # We need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            # if self.is_last_rank():
            virtual_engine = 0
            allow_async_output_proc = False
            ctx = self.scheduler_contexts[virtual_engine]
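            # Only virtual engine 0 is used and async output processing is
            # disabled, so the outputs are handled synchronously below.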
            # Clear outputs for each new scheduler iteration
            ctx.request_outputs.clear()
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, output)
            # Finish the current step for all the sequence groups.
            if self.scheduler_config.is_multi_step:
                for seq_group in seq_group_metadata_list:
                    seq_group.finish_step()
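            # RequestOutputs are only built once this batch has no steps left;
            # while multi-step decoding is still in flight, step_outputs stays empty.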
            if not self._has_remaining_steps(seq_group_metadata_list):
                # clear the cache if we have finished all the steps.
                if self.scheduler_config.is_multi_step:
                    self.cached_scheduler_outputs[0] = SchedulerOutputState()
                # is_first_step_output is True only when the num_steps of all
                # the sequences are 1. When num_steps > 1,
                # multi_step_model_runner does the first-step output append.
                is_first_step_output: bool = False if not seq_group_metadata_list \
                    else seq_group_metadata_list[0].state.num_steps == 1
                # Add results to the output_queue
                ctx.append_output(outputs=output,
                                  seq_group_metadata_list=seq_group_metadata_list,
                                  scheduler_outputs=self.scheduler_outputs,
                                  is_async=allow_async_output_proc,
                                  is_last_step=True,
                                  is_first_step_output=is_first_step_output)
                self._process_model_outputs(ctx=ctx)
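                # Once no unfinished requests remain, drain anything still queued
                # so the final RequestOutputs are complete.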
                if not self.has_unfinished_requests():
                    # Drain async postprocessor (if exists)
                    if len(ctx.output_queue) > 0:
                        self._process_model_outputs(ctx=ctx)
                    assert len(ctx.output_queue) == 0
                step_outputs = ctx.request_outputs
            else:
                # Multi-step case
                step_outputs = ctx.request_outputs
    else:
        raise RuntimeError(
            f"Unsupported vllm version {CURRENT_VLLM_VERSION}, "
            f"expected one of {list(VLLMVersion)}")
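
    # Collect finished requests, advance the progress bar, and return the number
    # of requests still pending.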
    done = 0
    for out in step_outputs:
        if out.finished:
            self.outputs.append(out)
            done += 1
            self.pbar.update(1)
    self.num_requests -= done
    if self.num_requests <= 0:
        self.pbar.close()
    if self._log_metrics:
        self.log_metrics_stats(done)
    return self.num_requests