in trl/scripts/vllm_serve.py
def main(script_args: ScriptArguments):
if not is_fastapi_available():
raise ImportError(
"FastAPI is required to run the vLLM serve script. Please install it using `pip install fastapi`."
)
if not is_pydantic_available():
raise ImportError(
"Pydantic is required to run the vLLM serve script. Please install it using `pip install pydantic`."
)
if not is_uvicorn_available():
raise ImportError(
"Uvicorn is required to run the vLLM serve script. Please install it using `pip install uvicorn`."
)
if not is_vllm_available():
raise ImportError("vLLM is required to run the vLLM serve script. Please install it using `pip install vllm`.")
# Spawn data-parallel workers and set up pipes for communication
master_port = get_open_port()
connections = []
processes = []
for data_parallel_rank in range(script_args.data_parallel_size):
parent_connection, child_connection = Pipe()
process = Process(target=llm_worker, args=(script_args, data_parallel_rank, master_port, child_connection))
process.start()
connections.append(parent_connection)
processes.append(process)
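# Illustrative sketch (not part of this script) of the message protocol the endpoints below rely
# on: each message sent over a pipe carries a "type", a worker-side "method" name, and optional
# "kwargs"; "call" messages expect a reply on the same pipe, "fire_and_forget" messages do not.
# The real worker loop lives in llm_worker elsewhere in this file; this is only a rough shape:
#
#     while True:
#         command = child_connection.recv()
#         result = getattr(llm, command["method"])(**command.get("kwargs", {}))
#         if command["type"] == "call":
#             child_connection.send(result)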
@asynccontextmanager
async def lifespan(app: FastAPI):
# Wait for all workers to send "ready"
ready_connections = set()
while len(ready_connections) < script_args.data_parallel_size:
for connection in connections:
msg = connection.recv()
if isinstance(msg, dict) and msg.get("status") == "ready":
ready_connections.add(connection)
yield
# Wait for processes to terminate
for process in processes:
process.join(timeout=10) # Wait for 10 seconds for the process to terminate
if process.is_alive():
logger.warning(f"Process {process} is still alive after 10 seconds, attempting to terminate...")
process.terminate()
process.join() # ensure process termination after calling terminate()
app = FastAPI(lifespan=lifespan)
# Define the endpoints for the model server
@app.get("/health/")
async def health():
"""
Health check endpoint to verify that the server is running.
"""
return {"status": "ok"}
@app.get("/get_world_size/")
async def get_world_size():
"""
Retrieves the world size of the LLM engine, which is `tensor_parallel_size * data_parallel_size`.
Returns:
`dict`:
A dictionary containing the world size.
Example response:
```json
{"world_size": 8}
```
"""
return {"world_size": script_args.tensor_parallel_size * script_args.data_parallel_size}
class GenerateRequest(BaseModel):
prompts: list[str]
n: int = 1
repetition_penalty: float = 1.0
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
min_p: float = 0.0
max_tokens: int = 16
guided_decoding_regex: Optional[str] = None
generation_kwargs: dict = field(default_factory=dict)
class GenerateResponse(BaseModel):
completion_ids: list[list[int]]
@app.post("/generate/", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
"""
Generates completions for the provided prompts.
Args:
request (`GenerateRequest`):
- `prompts` (list of `str`): A list of prompts (text strings) for the model to generate completions.
- `n` (`int`, *optional*, defaults to `1`): Number of completions to generate for each prompt.
- `repetition_penalty` (`float`, *optional*, defaults to `1.0`): Repetition penalty to apply during generation.
- `temperature` (`float`, *optional*, defaults to `1.0`): Temperature for sampling. Higher values lead to more random outputs.
- `top_p` (`float`, *optional*, defaults to `1.0`): Top-p (nucleus) sampling parameter. It controls the diversity of the generated text.
- `top_k` (`int`, *optional*, defaults to `-1`): Top-k sampling parameter. If set to `-1`, it disables top-k sampling.
- `min_p` (`float`, *optional*, defaults to `0.0`): Minimum probability threshold for sampling.
- `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each completion.
- `guided_decoding_regex` (`str`, *optional*): A regex pattern for guided decoding. If provided, the model will only generate tokens that match this regex pattern.
- `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM `SamplingParams`. This can include parameters like `seed`, `frequency_penalty`, etc. Keys provided here take precedence over the explicitly listed parameters above.
Returns:
`GenerateResponse`:
- `completion_ids` (list of list of `int`): A list of lists of token IDs for each generated completion.
Example request:
```json
{"prompts": ["Hello world", "What is AI?"]}
```
Example response:
```json
{"completion_ids": [[101, 102, 103], [201, 202, 203]]}
```
"""
# Guided decoding, if enabled
if request.guided_decoding_regex is not None:
guided_decoding = GuidedDecodingParams(backend="outlines", regex=request.guided_decoding_regex)
else:
guided_decoding = None
generation_kwargs = {
"n": request.n,
"repetition_penalty": request.repetition_penalty,
"temperature": request.temperature,
"top_p": request.top_p,
"top_k": request.top_k,
"min_p": request.min_p,
"max_tokens": request.max_tokens,
"guided_decoding": guided_decoding,
}
generation_kwargs.update(request.generation_kwargs)
sampling_params = SamplingParams(**generation_kwargs)
# Evenly distribute prompts across DP ranks
chunked_prompts = chunk_list(request.prompts, script_args.data_parallel_size)
# Send the prompts to each worker
for connection, prompts in zip(connections, chunked_prompts):
# When the number of prompts is less than data_parallel_size, some workers will receive empty prompts.
# However, vLLM requires that we always send at least one prompt. So we send a placeholder prompt to comply
# with vLLM's requirement, and we later ignore the result.
if not prompts:
prompts = ["<placeholder>"]
kwargs = {"prompts": prompts, "sampling_params": sampling_params}
connection.send({"type": "call", "method": "generate", "kwargs": kwargs})
# Receive results
all_outputs = [connection.recv() for connection in connections]
# Handle empty prompts (see above)
all_outputs = [output for output, prompts in zip(all_outputs, chunked_prompts) if prompts]
# Flatten and combine all results
all_outputs = list(chain.from_iterable(all_outputs)) # from list of list to single list
completion_ids = [list(output.token_ids) for request_output in all_outputs for output in request_output.outputs]
return {"completion_ids": completion_ids}
class InitCommunicatorRequest(BaseModel):
host: str
port: int
world_size: int
@app.post("/init_communicator/")
async def init_communicator(request: InitCommunicatorRequest):
"""
Initializes the communicator for synchronizing model weights between a client and multiple server workers.
Args:
request (`InitCommunicatorRequest`):
- `host` (`str`): Hostname or IP address of the master node.
- `port` (`int`): Port number to be used for communication.
- `world_size` (`int`): Total number of participating processes in the group.
"""
# The world size is recomputed server-side: one rank per vLLM worker plus one for the client process.
world_size = script_args.tensor_parallel_size * script_args.data_parallel_size + 1
# The worker-side init_communicator has the signature init_communicator(host, port, world_size), so with
# collective_rpc it is invoked as:
# llm.collective_rpc(method="init_communicator", args=(host, port, world_size))
kwargs = {"method": "init_communicator", "args": (request.host, request.port, world_size)}
for connection in connections:
connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs})
return {"message": "Request received, initializing communicator"}
class UpdateWeightsRequest(BaseModel):
name: str
dtype: str
shape: list[int]
@app.post("/update_named_param/")
async def update_named_param(request: UpdateWeightsRequest):
"""
Updates the model weights with the provided tensor.
Once this endpoint is called, the client process should broadcast the updated weights to all server workers.
Args:
request (`UpdateWeightsRequest`):
- `name` (`str`): Name of the weight tensor being updated.
- `dtype` (`str`): Data type of the weight tensor (e.g., `"torch.float32"`).
- `shape` (list of `int`): Shape of the weight tensor.
"""
# The worker-side update_named_param has the signature update_named_param(name, dtype, shape), e.g.
# update_named_param("name", torch.float32, (10, 10)), so with collective_rpc it is invoked as:
# llm.collective_rpc("update_named_param", args=("name", torch.float32, (10, 10)))
dtype = getattr(torch, request.dtype.split(".")[-1])  # e.g. "torch.float32" -> torch.float32
kwargs = {"method": "update_named_param", "args": (request.name, dtype, tuple(request.shape))}
for connection in connections:
connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs})
return {"message": "Request received, updating named parameter"}
@app.post("/reset_prefix_cache/")
async def reset_prefix_cache():
"""
Resets the prefix cache for the model.
"""
for connection in connections:
connection.send({"type": "call", "method": "reset_prefix_cache"})
# Wait for and collect all results
all_outputs = [connection.recv() for connection in connections]
success = all(all_outputs)
return {"message": f"Request received, resetting prefix cache status: {success}"}
@app.post("/close_communicator/")
async def close_communicator():
"""
Closes the weight update group and cleans up associated resources.
"""
kwargs = {"method": "close_communicator"}
for connection in connections:
connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs})
return {"message": "Request received, closing communicator"}
# Start the server
uvicorn.run(app, host=script_args.host, port=script_args.port, log_level=script_args.log_level)
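# Illustrative launch (not part of this script): the exact entry point and flags depend on how TRL
# is installed and on the fields defined in ScriptArguments; something along these lines starts the
# server with 2 data-parallel workers:
#
#     python -m trl.scripts.vllm_serve --model Qwen/Qwen2.5-0.5B-Instruct --data_parallel_size 2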