llm_swarm/__init__.py
from dataclasses import dataclass
import os
import subprocess
import time
from typing import List, Literal, Optional, TypeVar
from huggingface_hub import get_session
import requests
from transformers import AutoTokenizer
from itertools import cycle
from shutil import get_terminal_size
from threading import Thread
from time import sleep
import socket
DataclassT = TypeVar("DataclassT")
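# folder where the Slurm jobs are expected to write their logs (llm-swarm_<job_id>.out, see below)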
SLURM_LOGS_FOLDER = "slurm/logs"
@dataclass
class LLMSwarmConfig:
instances: int = 1
"""number of inference instances"""
inference_engine: Literal["tgi", "vllm"] = "tgi"
"""inference engine to use"""
slurm_template_path: str = "templates/tgi_h100.template.slurm"
"""path to slurm template"""
model: str = "mistralai/Mistral-7B-Instruct-v0.1"
"""the name of the HF model or path to use"""
revision: str = "main"
"""the revision of the model to use"""
gpus: int = 8
"""number of gpus per instance"""
load_balancer_template_path: str = "templates/nginx.template.conf"
"""path to load balancer template"""
per_instance_max_parallel_requests: int = 500
"""maximum number of parallel requests per instance"""
debug_endpoint: Optional[str] = None
"""endpoint to use for debugging (e.g. http://localhost:13120)"""
def run_command(command: str):
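    """Run a shell command, assert that it exited with code 0, and return its stripped stdout."""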
print(f"running {command}")
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
output, errors = process.communicate()
return_code = process.returncode
assert return_code == 0, f"Command failed with error: {errors.decode('utf-8')}"
return output.decode("utf-8").strip()
def is_job_running(job_id: str):
"""Given job id, check if the job is in eunning state (needed to retrieve hostname from logs)"""
command = "squeue --me --states=R | awk '{print $1}' | tail -n +2"
my_running_jobs = subprocess.run(command, shell=True, text=True, capture_output=True).stdout.splitlines()
return job_id in my_running_jobs
def make_sure_jobs_are_still_running(job_ids: List[str]):
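    """Raise an error if any of the given Slurm jobs is no longer running (no-op if `job_ids` is empty)."""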
if job_ids:
for job_id in job_ids:
if not is_job_running(job_id):
                slurm_log_path = os.path.join(SLURM_LOGS_FOLDER, f"llm-swarm_{job_id}.out")
                print(f"\n❌ Failed! Job {job_id} is not running; check out {slurm_log_path}")
                raise RuntimeError(f"Job {job_id} is no longer running")
def get_unused_port(start=50000, end=65535):
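    """Return a free TCP port in [start, end], found by trying to bind to each port in turn."""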
for port in range(start, end + 1):
try:
sock = socket.socket()
sock.bind(("", port))
sock.listen(1)
sock.close()
return port
except OSError:
continue
raise IOError("No free ports available in range {}-{}".format(start, end))
def test_generation(endpoint):
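    """Send a small warm-up generation request to the endpoint (the response itself is not checked)."""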
headers = {
"Content-Type": "application/json",
}
data = {
"inputs": "What is Deep Learning?",
"parameters": {
"max_new_tokens": 200,
},
}
requests.post(endpoint, headers=headers, json=data)
print("✅ test generation")
class Loader:
def __init__(self, desc="Loading...", end="✅ Done!", failed="❌ Aborted!", timeout=0.1):
"""
A loader-like context manager
Modified from https://stackoverflow.com/a/66558182/6611317
Args:
desc (str, optional): The loader's description. Defaults to "Loading...".
end (str, optional): Final print. Defaults to "Done!".
failed (str, optional): Final print on failure. Defaults to "Aborted!".
timeout (float, optional): Sleep time between prints. Defaults to 0.1.
"""
self.desc = desc
self.end = end + " " + self.desc
self.failed = failed + " " + self.desc
self.timeout = timeout
self._thread = Thread(target=self._animate, daemon=True)
self.steps = ["⢿", "⣻", "⣽", "⣾", "⣷", "⣯", "⣟", "⡿"]
self.done = False
def start(self):
self._thread.start()
return self
def _animate(self):
try:
for c in cycle(self.steps):
if self.done:
break
print(f"\r{c} {self.desc}", flush=True, end="")
sleep(self.timeout)
except KeyboardInterrupt:
self.stop()
print("KeyboardInterrupt by user")
    def __enter__(self):
        return self.start()
def stop(self):
self.done = True
cols = get_terminal_size((80, 20)).columns
print("\r" + " " * cols, end="", flush=True)
print(f"\r{self.end}", flush=True)
def __exit__(self, exc_type, exc_value, tb):
if exc_type is None:
self.stop()
else:
self.done = True
cols = get_terminal_size((80, 20)).columns
print("\r" + " " * cols, end="", flush=True)
print(f"\r{self.failed}", flush=True)
def get_endpoints(endpoint_path: str, instances: int = 1, job_ids: Optional[List[str]] = None) -> List[str]:
"""Return list of endpoints from either a file or a comma separated string.
It also checks if the endpoints are reachable.
Args:
endpoint_path (str): path to file containing endpoints or comma separated string
instances (int, optional): number of instances. Defaults to 1.
Returns:
List[str]: list of endpoints (e.g. ["http://26.0.154.245:13120"])
"""
trying = True
with Loader(f"Waiting for {endpoint_path} to be created"):
while trying:
try:
endpoints = open(endpoint_path).read().splitlines()
assert (
len(endpoints) == instances
), f"#endpoints {len(endpoints)} doesn't match #instances {instances}" # could read an empty file
# due to race condition (slurm writing & us reading)
trying = False
except (OSError, AssertionError):
make_sure_jobs_are_still_running(job_ids)
sleep(1)
print("obtained endpoints", endpoints)
for endpoint in endpoints:
with Loader(f"Waiting for {endpoint} to be reachable"):
connected = False
while not connected:
try:
get_session().get(f"{endpoint}/health")
print(f"\nConnected to {endpoint}")
connected = True
except requests.exceptions.ConnectionError:
make_sure_jobs_are_still_running(job_ids)
sleep(1)
return endpoints
class LLMSwarm:
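    """Manage a swarm of LLM inference instances on a Slurm cluster.
    `start()` submits one Slurm job per instance, waits for every endpoint to be written and reachable,
    and, when more than one instance is used, fronts them with an nginx load balancer run via Docker.
    Typical usage (mirroring the `__main__` block below):
        with LLMSwarm(LLMSwarmConfig(instances=3, inference_engine="tgi")) as llm_swarm:
            ...  # send requests to llm_swarm.endpoint
    """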
def __init__(self, config: LLMSwarmConfig) -> None:
self.config = config
self.cleaned_up = False
self.tokenizer = AutoTokenizer.from_pretrained(config.model, revision=config.revision)
os.makedirs(SLURM_LOGS_FOLDER, exist_ok=True)
def start(self):
# if debug endpoint is provided, use it as is
if self.config.debug_endpoint:
self.endpoint = self.config.debug_endpoint
if self.config.inference_engine == "vllm":
self.endpoint = f"{self.config.debug_endpoint}/generate"
if self.config.debug_endpoint.startswith("https://api-inference.huggingface.co/"):
self.suggested_max_parallel_requests = 40
else:
self.suggested_max_parallel_requests = self.config.per_instance_max_parallel_requests * self.config.instances
return
self.suggested_max_parallel_requests = self.config.per_instance_max_parallel_requests * self.config.instances
with open(self.config.slurm_template_path) as f:
slurm_template = f.read()
# customize slurm template
self.filename = f"{self.config.inference_engine}_{int(time.time())}"
slurm_path = os.path.join("slurm", f"{self.filename}_{self.config.inference_engine}.slurm")
slurm_host_path = os.path.join("slurm", f"{self.filename}_host_{self.config.inference_engine}.txt")
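        # the Slurm template is expected to write one endpoint per line (e.g. "http://host:port") to this
        # hosts file; get_endpoints() below polls it to discover the running instances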
slurm_template = slurm_template.replace(r"{{slurm_hosts_path}}", slurm_host_path)
slurm_template = slurm_template.replace(r"{{model}}", self.config.model)
slurm_template = slurm_template.replace(r"{{revision}}", self.config.revision)
slurm_template = slurm_template.replace(r"{{gpus}}", str(self.config.gpus))
slurm_template = slurm_template.replace(r"{{model_max_length}}", str(min(self.tokenizer.model_max_length, 32768)))
slurm_template = slurm_template.replace(r"{{model_input_length}}", str(min(self.tokenizer.model_max_length - 100, 32768 - 100))) # `model_input_length` needs to be smaller than `model_max_length`
with open(slurm_path, "w") as f:
f.write(slurm_template)
# start inference instances
self.job_ids = [run_command(f"sbatch --parsable {slurm_path}") for _ in range(self.config.instances)]
print(f"Slurm Job ID: {self.job_ids}")
print(f"📖 Slurm hosts path: {slurm_host_path}")
self.container_id = None
try:
# ensure job is running
for job_id in self.job_ids:
with Loader(f"Waiting for {job_id} to be created"):
while not is_job_running(job_id):
sleep(1)
                slurm_log_path = os.path.join(SLURM_LOGS_FOLDER, f"llm-swarm_{job_id}.out")
                print(f"📖 Slurm log path: {slurm_log_path}")
# retrieve endpoints
self.endpoints = get_endpoints(slurm_host_path, self.config.instances, self.job_ids)
print(f"Endpoints running properly: {self.endpoints}")
# warm up endpoints
for endpoint in self.endpoints:
test_generation(endpoint)
if len(self.endpoints) == 1:
print(f"🔥 endpoint ready {self.endpoints[0]}")
self.endpoint = self.endpoints[0]
else:
# run a load balancer
with open(self.config.load_balancer_template_path) as f:
# templates/nginx.template.conf
load_balancer_template = f.read()
servers = "\n".join([f"server {endpoint.replace('http://', '')};" for endpoint in self.endpoints])
unused_port = get_unused_port()
load_balancer_template = load_balancer_template.replace(r"{{servers}}", servers)
load_balancer_template = load_balancer_template.replace(r"{{port}}", str(unused_port))
load_balancer_path = os.path.join("slurm", f"{self.filename}_load_balancer.conf")
with open(load_balancer_path, "w") as f:
f.write(load_balancer_template)
load_balance_endpoint = f"http://localhost:{unused_port}"
command = f"sudo docker run -d -p {unused_port}:{unused_port} --network host -v $(pwd)/{load_balancer_path}:/etc/nginx/nginx.conf nginx"
load_balance_endpoint_connected = False
# run docker streaming output while we validate the endpoints
self.container_id = run_command(command)
last_line = 0
while True:
logs = run_command(f"sudo docker logs {self.container_id}")
lines = logs.split("\n")
for line in lines[last_line:]:
print(line)
last_line = len(lines)
if not load_balance_endpoint_connected:
try:
get_session().get(f"{load_balance_endpoint}/health")
print(f"🔥 endpoint ready {load_balance_endpoint}")
load_balance_endpoint_connected = True
self.endpoint = load_balance_endpoint
break
except requests.exceptions.ConnectionError:
sleep(1)
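        # vLLM's API server exposes generation under /generate, while TGI is queried at the endpoint root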
if self.config.inference_engine == "vllm":
self.endpoint = f"{self.endpoint}/generate"
        except (KeyboardInterrupt, Exception):
            self.cleanup()
            raise
def __enter__(self):
self.start()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.cleanup()
def cleanup(self, signum=None, frame=None):
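        """Cancel the Slurm jobs and kill the load-balancer container (idempotent; also usable as a signal handler)."""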
if self.config.debug_endpoint:
return
if self.cleaned_up:
return
for job_id in self.job_ids:
run_command(f"scancel {job_id}")
print("inference instances terminated")
if self.container_id:
run_command(f"sudo docker kill {self.container_id}")
print("docker process terminated")
self.cleaned_up = True
if __name__ == "__main__":
with LLMSwarm(
LLMSwarmConfig(
instances=3,
inference_engine="tgi",
slurm_template_path="templates/tgi_h100.template.slurm",
load_balancer_template_path="templates/nginx.template.conf",
)
) as llm_swarm:
        input("Press Enter to EXIT...")