in llm_swarm/__init__.py [0:0]
def start(self):
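"""Provision the inference backend and set `self.endpoint`.

If `config.debug_endpoint` is set, reuse it as-is. Otherwise render the Slurm
template, submit one job per instance, wait for the jobs and their endpoints,
and (when more than one instance is running) start an nginx load balancer in
front of them.
"""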
# if debug endpoint is provided, use it as is
if self.config.debug_endpoint:
self.endpoint = self.config.debug_endpoint
if self.config.inference_engine == "vllm":
self.endpoint = f"{self.config.debug_endpoint}/generate"
if self.config.debug_endpoint.startswith("https://api-inference.huggingface.co/"):
self.suggested_max_parallel_requests = 40
else:
self.suggested_max_parallel_requests = self.config.per_instance_max_parallel_requests * self.config.instances
return
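# no debug endpoint: provision our own instances via Slurm; total suggested
# parallelism scales with the number of instances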
self.suggested_max_parallel_requests = self.config.per_instance_max_parallel_requests * self.config.instances
with open(self.config.slurm_template_path) as f:
slurm_template = f.read()
# customize slurm template
self.filename = f"{self.config.inference_engine}_{int(time.time())}"
slurm_path = os.path.join("slurm", f"{self.filename}_{self.config.inference_engine}.slurm")
slurm_host_path = os.path.join("slurm", f"{self.filename}_host_{self.config.inference_engine}.txt")
slurm_template = slurm_template.replace(r"{{slurm_hosts_path}}", slurm_host_path)
slurm_template = slurm_template.replace(r"{{model}}", self.config.model)
slurm_template = slurm_template.replace(r"{{revision}}", self.config.revision)
slurm_template = slurm_template.replace(r"{{gpus}}", str(self.config.gpus))
slurm_template = slurm_template.replace(r"{{model_max_length}}", str(min(self.tokenizer.model_max_length, 32768)))
slurm_template = slurm_template.replace(r"{{model_input_length}}", str(min(self.tokenizer.model_max_length - 100, 32768 - 100))) # `model_input_length` needs to be smaller than `model_max_length`
with open(slurm_path, "w") as f:
f.write(slurm_template)
# start inference instances
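# `sbatch --parsable` prints only the job ID, so each submission yields one ID to track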
self.job_ids = [run_command(f"sbatch --parsable {slurm_path}") for _ in range(self.config.instances)]
print(f"Slurm Job ID: {self.job_ids}")
print(f"📖 Slurm hosts path: {slurm_host_path}")
self.container_id = None
try:
# ensure job is running
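# block until Slurm reports each submitted job as running before looking for endpoints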
for job_id in self.job_ids:
with Loader(f"Waiting for {job_id} to be created"):
while not is_job_running(job_id):
sleep(1)
slurm_log_path = os.path.join(SLURM_LOGS_FOLDER, f"llm-swarm_{job_id}.out")
print(f"📖 Slurm log path: {slurm_log_path}")

# retrieve endpoints
self.endpoints = get_endpoints(slurm_host_path, self.config.instances, self.job_ids)
print(f"Endpoints running properly: {self.endpoints}")
# warm up endpoints
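# issue a small test generation against each instance so failures surface before real traffic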
for endpoint in self.endpoints:
test_generation(endpoint)
if len(self.endpoints) == 1:
print(f"🔥 endpoint ready {self.endpoints[0]}")
self.endpoint = self.endpoints[0]
else:
# run a load balancer
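# with several instances, front them with an nginx container; the rendered
# config lists every instance as an upstream server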
with open(self.config.load_balancer_template_path) as f:
# templates/nginx.template.conf
load_balancer_template = f.read()
servers = "\n".join([f"server {endpoint.replace('http://', '')};" for endpoint in self.endpoints])
unused_port = get_unused_port()
load_balancer_template = load_balancer_template.replace(r"{{servers}}", servers)
load_balancer_template = load_balancer_template.replace(r"{{port}}", str(unused_port))
load_balancer_path = os.path.join("slurm", f"{self.filename}_load_balancer.conf")
with open(load_balancer_path, "w") as f:
f.write(load_balancer_template)
load_balance_endpoint = f"http://localhost:{unused_port}"
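# run nginx detached with host networking, mounting the rendered config over the container's default nginx.conf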
command = f"sudo docker run -d -p {unused_port}:{unused_port} --network host -v $(pwd)/{load_balancer_path}:/etc/nginx/nginx.conf nginx"
load_balance_endpoint_connected = False
# run docker streaming output while we validate the endpoints
self.container_id = run_command(command)
last_line = 0
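# poll: print any new nginx log lines, then probe /health until the load balancer answers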
while True:
logs = run_command(f"sudo docker logs {self.container_id}")
lines = logs.split("\n")
for line in lines[last_line:]:
print(line)
last_line = len(lines)
if not load_balance_endpoint_connected:
try:
get_session().get(f"{load_balance_endpoint}/health")
print(f"🔥 endpoint ready {load_balance_endpoint}")
load_balance_endpoint_connected = True
self.endpoint = load_balance_endpoint
break
except requests.exceptions.ConnectionError:
sleep(1)
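# vLLM's HTTP server serves completions under /generate, so append it to the
# base URL (mirrors the debug-endpoint branch above)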
if self.config.inference_engine == "vllm":
self.endpoint = f"{self.endpoint}/generate"
except (KeyboardInterrupt, Exception):
self.cleanup()
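For context, a minimal usage sketch of the method above. The class and config names (LLMSwarm, LLMSwarmConfig) are assumptions based on this excerpt rather than shown in it, and the field values are illustrative placeholders; only fields that start() reads are set.

# Hedged usage sketch: assumes the enclosing class is LLMSwarm and that its
# config object is LLMSwarmConfig with the fields read by start() above.
from llm_swarm import LLMSwarm, LLMSwarmConfig

swarm = LLMSwarm(
    LLMSwarmConfig(
        inference_engine="tgi",                  # or "vllm"
        instances=2,                             # number of Slurm jobs to launch
        per_instance_max_parallel_requests=128,  # used for suggested_max_parallel_requests
        slurm_template_path="templates/tgi.template.slurm",          # placeholder path
        load_balancer_template_path="templates/nginx.template.conf", # placeholder path
        model="org/model-name",                  # placeholder model ID
        revision="main",
        gpus=8,
        debug_endpoint=None,                     # set to an existing URL to skip Slurm entirely
    )
)
try:
    swarm.start()
    print(swarm.endpoint, swarm.suggested_max_parallel_requests)
finally:
    swarm.cleanup()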