in chatlearn/models/vllm_module.py [0:0]
def generate_vllm(self, query, is_eval, iteration=0): # pylint: disable=unused-argument
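    """Generate responses for `query` with vLLM.

    All ranks first agree on a common KV-cache size; the last rank then acts
    as the driver: it owns the scheduler, schedules each step, and broadcasts
    both the schedule and a stop signal to the other ranks.
    """
    # Profile available KV-cache blocks locally, then all-reduce with MIN so
    # every rank allocates the same number of GPU/CPU cache blocks.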
    num_gpu_blocks, num_cpu_blocks = self.profile_cache_blocks()
    num_blocks = torch.tensor([num_gpu_blocks, num_cpu_blocks], device='cuda')
    torch.distributed.all_reduce(num_blocks, op=torch.distributed.ReduceOp.MIN)
    min_gpu_blocks = num_blocks[0].item()
    min_cpu_blocks = num_blocks[1].item()
    self.set_cache_config(min_gpu_blocks, min_cpu_blocks)
    if self.is_last_rank():
        self.build_scheduler()
    self.reinit_cache_engine()
    # Add the requests of the current episode to the vLLM scheduler
    # (the scheduler lives on the last rank only).
    if self.is_last_rank():
        self._add_request(query, is_eval=is_eval)
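
    # Step until the driver (last rank) reports that no requests remain;
    # follower ranks learn this from the stop signal broadcast after each step.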
    step_outputs = True
    while step_outputs:
        schedule_query = None
        if self.is_last_rank():
            # Support multi-step scheduling: reuse cached scheduler outputs
            # when the previous schedule still has steps left to run.
            virtual_engine = 0
            cached_outputs = self.cached_scheduler_outputs[virtual_engine]
            seq_group_metadata_list = cached_outputs.seq_group_metadata_list
            scheduler_outputs = cached_outputs.scheduler_outputs
            allow_async_output_proc = False

            ctx = self.scheduler_contexts[virtual_engine]

            # Clear outputs for each new scheduler iteration
            ctx.request_outputs.clear()
            # Skip the scheduler if there are any remaining steps in the seq
            # groups. This ensures that the scheduler is only called again
            # when the current batch has completed.
            if not self._has_remaining_steps(seq_group_metadata_list):
                # Schedule iteration
                scheduler_outputs = self.schedule()
                seq_group_metadata_list = scheduler_outputs["seq_group_metadata_list"]

                ctx.seq_group_metadata_list = seq_group_metadata_list
                ctx.scheduler_outputs = scheduler_outputs

                # Maybe switch from async mode to sync mode
                if not allow_async_output_proc and len(ctx.output_queue) > 0:
                    self._process_model_outputs(ctx=ctx)

                if (self.scheduler_config.is_multi_step
                        and scheduler_outputs["num_lookahead_slots"] > 0):
                    # Cache the scheduler outputs for the next iteration
                    # if we have lookahead slots.
                    self._cache_scheduler_outputs_for_multi_step(
                        virtual_engine, seq_group_metadata_list, scheduler_outputs,
                        allow_async_output_proc)

            assert seq_group_metadata_list is not None
            assert scheduler_outputs is not None

            schedule_query = scheduler_outputs
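            # An empty schedule means there is nothing left to execute;
            # flush any model outputs still queued in the context.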
            if len(scheduler_outputs) == 0:
                if len(ctx.output_queue) > 0:
                    self._process_model_outputs(ctx=ctx)

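        # broadcast_var_object_dict is assumed here to be a ChatLearn utility
        # that sends a picklable object from the source rank (the last rank)
        # to all other ranks, so every rank executes the same step.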
        schedule_query = broadcast_var_object_dict(schedule_query, torch.distributed.get_world_size() - 1)
        output = self.execute_step(schedule_query)

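        # Stop-signal handshake: only the driver knows whether any requests
        # remain, so it broadcasts that flag and all ranks exit the loop together.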
        if self.is_last_rank():
            step_outputs = bool(output)
            signal_tensor = torch.tensor(step_outputs, device='cuda')
            torch.distributed.broadcast(signal_tensor, torch.distributed.get_world_size() - 1)
        else:
            signal_tensor = torch.tensor(True, device='cuda')
            torch.distributed.broadcast(signal_tensor, torch.distributed.get_world_size() - 1)
            step_outputs = signal_tensor.item()
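
    # Only the driver rank holds the generated outputs; follower ranks
    # fall through and implicitly return None.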
    if self.is_last_rank():
        self.outputs = sorted(self.outputs, key=lambda x: int(x.request_id))
        return self.outputs
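
# A minimal sketch (an illustration, not part of vllm_module.py) of the
# contract broadcast_var_object_dict is assumed to satisfy above: broadcast
# a picklable object from `src_rank` to every rank in the default group.
def broadcast_object_sketch(obj, src_rank):
    holder = [obj]  # broadcast_object_list fills this list in place on receivers
    torch.distributed.broadcast_object_list(holder, src=src_rank)
    return holder[0]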