def start()

in llm_swarm/__init__.py [0:0]
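
Starts the inference backend: with a configured debug endpoint it is used as-is (plus a `/generate` suffix for vLLM) and nothing is launched; otherwise the method renders the Slurm template, submits one job per instance, waits for the jobs to run, warms up the retrieved endpoints, and, when there is more than one instance, fronts them with an nginx load balancer running in Docker.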


    def start(self):
        # if debug endpoint is provided, use it as is
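        # (e.g. debug_endpoint="http://localhost:8000" with inference_engine="vllm"
        # routes requests to "http://localhost:8000/generate"; hypothetical URL)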
        if self.config.debug_endpoint:
            self.endpoint = self.config.debug_endpoint
            if self.config.inference_engine == "vllm":
                self.endpoint = f"{self.config.debug_endpoint}/generate"
            if self.config.debug_endpoint.startswith("https://api-inference.huggingface.co/"):
                self.suggested_max_parallel_requests = 40
            else:
                self.suggested_max_parallel_requests = self.config.per_instance_max_parallel_requests * self.config.instances
            return

        self.suggested_max_parallel_requests = self.config.per_instance_max_parallel_requests * self.config.instances
        with open(self.config.slurm_template_path) as f:
            slurm_template = f.read()

        # customize slurm template
        self.filename = f"{self.config.inference_engine}_{int(time.time())}"
        slurm_path = os.path.join("slurm", f"{self.filename}_{self.config.inference_engine}.slurm")
        slurm_host_path = os.path.join("slurm", f"{self.filename}_host_{self.config.inference_engine}.txt")
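        # the template marks substitution points as `{{placeholder}}`; each replace below fills one in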
        slurm_template = slurm_template.replace(r"{{slurm_hosts_path}}", slurm_host_path)
        slurm_template = slurm_template.replace(r"{{model}}", self.config.model)
        slurm_template = slurm_template.replace(r"{{revision}}", self.config.revision)
        slurm_template = slurm_template.replace(r"{{gpus}}", str(self.config.gpus))
        slurm_template = slurm_template.replace(r"{{model_max_length}}", str(min(self.tokenizer.model_max_length, 32768)))
        slurm_template = slurm_template.replace(r"{{model_input_length}}", str(min(self.tokenizer.model_max_length - 100, 32768 - 100))) # `model_input_length` needs to be smaller than `model_max_length`
        with open(slurm_path, "w") as f:
            f.write(slurm_template)

        # start inference instances
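        # `sbatch --parsable` prints just the job id, so run_command returns one id per submission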
        self.job_ids = [run_command(f"sbatch --parsable {slurm_path}") for _ in range(self.config.instances)]
        print(f"Slurm Job ID: {self.job_ids}")
        print(f"📖 Slurm hosts path: {slurm_host_path}")

        self.container_id = None
        try:
            # ensure job is running
            for job_id in self.job_ids:
                with Loader(f"Waiting for {job_id} to be created"):
                    while not is_job_running(job_id):
                        sleep(1)
                slurm_log_path = os.path.join(SLURM_LOGS_FOLDER, f"llm-swarm_{job_id}.out")
                print(f"📖 Slurm log path: {slurm_log_path}")
            # retrieve endpoints
            self.endpoints = get_endpoints(slurm_host_path, self.config.instances, self.job_ids)
            print(f"Endpoints running properly: {self.endpoints}")
            # warm up endpoints
            for endpoint in self.endpoints:
                test_generation(endpoint)

            if len(self.endpoints) == 1:
                print(f"🔥 endpoint ready {self.endpoints[0]}")
                self.endpoint = self.endpoints[0]
            else:
                # run a load balancer
                with open(self.config.load_balancer_template_path) as f:
                    # templates/nginx.template.conf
                    load_balancer_template = f.read()
                servers = "\n".join([f"server {endpoint.replace('http://', '')};" for endpoint in self.endpoints])
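                # yields one nginx upstream `server` directive per instance,
                # e.g. "server host1:8080;\nserver host2:8080;" (hypothetical hosts)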
                unused_port = get_unused_port()
                load_balancer_template = load_balancer_template.replace(r"{{servers}}", servers)
                load_balancer_template = load_balancer_template.replace(r"{{port}}", str(unused_port))
                load_balancer_path = os.path.join("slurm", f"{self.filename}_load_balancer.conf")
                with open(load_balancer_path, "w") as f:
                    f.write(load_balancer_template)
                load_balance_endpoint = f"http://localhost:{unused_port}"
                command = f"sudo docker run -d -p {unused_port}:{unused_port} --network host -v $(pwd)/{load_balancer_path}:/etc/nginx/nginx.conf nginx"
                load_balance_endpoint_connected = False

                # run docker streaming output while we validate the endpoints
                self.container_id = run_command(command)
                last_line = 0
                while True:
                    logs = run_command(f"sudo docker logs {self.container_id}")
                    lines = logs.split("\n")
                    for line in lines[last_line:]:
                        print(line)
                    last_line = len(lines)

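                    # keep probing /health until nginx answers, then adopt the balancer endpoint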
                    if not load_balance_endpoint_connected:
                        try:
                            get_session().get(f"{load_balance_endpoint}/health")
                            print(f"🔥 endpoint ready {load_balance_endpoint}")
                            load_balance_endpoint_connected = True
                            self.endpoint = load_balance_endpoint
                            break
                        except requests.exceptions.ConnectionError:
                            sleep(1)
            if self.config.inference_engine == "vllm":
                self.endpoint = f"{self.endpoint}/generate"
        except (KeyboardInterrupt, Exception):
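            # best-effort teardown of any Slurm jobs / the nginx container already launched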
            self.cleanup()
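
For context, here is a minimal usage sketch. The `LLMSwarm` / `LLMSwarmConfig` names and the context-manager entry point follow the project README, and `start()` is assumed to run on `__enter__`; treat the field values and paths as hypothetical.

    # hedged usage sketch, not part of the source: names follow the project README,
    # field values are hypothetical
    from llm_swarm import LLMSwarm, LLMSwarmConfig

    with LLMSwarm(
        LLMSwarmConfig(
            instances=2,                 # number of Slurm jobs to submit
            inference_engine="tgi",      # "vllm" would append /generate to the endpoint
            slurm_template_path="templates/tgi_h100.template.slurm",
            load_balancer_template_path="templates/nginx.template.conf",
        )
    ) as llm_swarm:
        # with instances > 1, this is the nginx load balancer address
        print(llm_swarm.endpoint)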