in fastchat/serve/multi_model_worker.py [0:0]
def _build_multi_model_arg_parser():
    """Build the CLI parser for the multi-model worker."""
    # Note: Ensure we resolve arg conflicts. We let `add_model_args` add MOST
    # of the model args but we'll override some of them to use an append
    # action that supports multiple values.
    parser = argparse.ArgumentParser(conflict_handler="resolve")
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    add_model_args(parser)
    # Override the model path to be repeated and align it with model names.
    parser.add_argument(
        "--model-path",
        type=str,
        default=[],
        action="append",
        help="One or more paths to model weights to load. This can be a local folder or a Hugging Face repo ID.",
    )
    parser.add_argument(
        "--model-names",
        type=lambda s: s.split(","),
        action="append",
        help="One or more model names. Values must be aligned with `--model-path` values.",
    )
    parser.add_argument(
        "--conv-template",
        type=str,
        default=None,
        action="append",
        help="Conversation prompt template. Values must be aligned with `--model-path` values. If only one value is provided, it will be repeated for all models.",
    )
    parser.add_argument("--limit-worker-concurrency", type=int, default=5)
    parser.add_argument("--stream-interval", type=int, default=2)
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument(
        "--ssl",
        action="store_true",
        required=False,
        default=False,
        help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
    )
    return parser


def create_multi_model_worker():
    """Parse CLI args, start one ModelWorker per model and register them.

    Side effects:
        * Sets ``CUDA_VISIBLE_DEVICES`` when ``--gpus`` is given.
        * Adds every served model name to the module-level ``worker_map``.
        * POSTs a single ``/register_worker`` request to the controller that
          advertises the model names of all workers under the first worker's
          address.

    Returns:
        A ``(args, workers)`` tuple: the parsed argparse namespace (with
        ``model_names`` and ``conv_template`` normalized to one entry per
        model path) and the list of created workers.

    Raises:
        ValueError: If no ``--model-path`` is given, if ``--gpus`` lists fewer
            devices than ``--num-gpus``, or if ``--model-names`` /
            ``--conv-template`` do not line up with ``--model-path``.
        RuntimeError: If the controller does not answer the registration
            request with HTTP 200.
    """
    parser = _build_multi_model_arg_parser()
    args = parser.parse_args()
    logger.info(f"args: {args}")

    if not args.model_path:
        # Without this check the function would only fail much later with an
        # opaque IndexError on `workers[0]`.
        raise ValueError("At least one --model-path must be provided.")

    if args.gpus:
        if len(args.gpus.split(",")) < args.num_gpus:
            raise ValueError(
                f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
            )
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    if args.enable_exllama:
        exllama_config = ExllamaConfig(
            max_seq_len=args.exllama_max_seq_len,
            gpu_split=args.exllama_gpu_split,
            cache_8bit=args.exllama_cache_8bit,
        )
    else:
        exllama_config = None

    if args.enable_xft:
        xft_config = XftConfig(
            max_seq_len=args.xft_max_seq_len,
            data_type=args.xft_dtype,
        )
        if args.device != "cpu":
            print("xFasterTransformer now is only support CPUs. Reset device to CPU")
            args.device = "cpu"
    else:
        xft_config = None

    num_models = len(args.model_path)

    # Default each model's name to the last component of its path.
    if args.model_names is None:
        args.model_names = [[x.split("/")[-1]] for x in args.model_path]

    if args.conv_template is None:
        args.conv_template = [None] * num_models
    elif len(args.conv_template) == 1:  # Repeat the same template
        args.conv_template = args.conv_template * num_models

    # `zip` below would silently drop models on a length mismatch, so reject
    # misaligned arguments explicitly.
    if len(args.model_names) != num_models:
        raise ValueError(
            f"Got {len(args.model_names)} --model-names values for "
            f"{num_models} --model-path values; they must be aligned."
        )
    if len(args.conv_template) != num_models:
        raise ValueError(
            f"Got {len(args.conv_template)} --conv-template values for "
            f"{num_models} --model-path values; provide one value or one per model."
        )

    # Launch all workers
    workers = []
    for conv_template, model_path, model_names in zip(
        args.conv_template, args.model_path, args.model_names
    ):
        # `args.model_path` is a list here (append action), so the fallback
        # checkpoint must be this worker's own path, not the whole list.
        gptq_config = GptqConfig(
            ckpt=args.gptq_ckpt or model_path,
            wbits=args.gptq_wbits,
            groupsize=args.gptq_groupsize,
            act_order=args.gptq_act_order,
        )
        w = ModelWorker(
            args.controller_address,
            args.worker_address,
            worker_id,
            model_path,
            model_names,
            args.limit_worker_concurrency,
            args.no_register,
            device=args.device,
            num_gpus=args.num_gpus,
            max_gpu_memory=args.max_gpu_memory,
            load_8bit=args.load_8bit,
            cpu_offloading=args.cpu_offloading,
            gptq_config=gptq_config,
            exllama_config=exllama_config,
            xft_config=xft_config,
            stream_interval=args.stream_interval,
            conv_template=conv_template,
        )
        workers.append(w)
        for model_name in model_names:
            worker_map[model_name] = w

    # Register all models
    url = args.controller_address + "/register_worker"
    data = {
        "worker_name": workers[0].worker_addr,
        "check_heart_beat": not args.no_register,
        "worker_status": {
            "model_names": [m for w in workers for m in w.model_names],
            "speed": 1,
            "queue_length": sum(w.get_queue_length() for w in workers),
        },
    }
    r = requests.post(url, json=data)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if r.status_code != 200:
        raise RuntimeError(
            f"Failed to register workers with controller at {url}: "
            f"HTTP {r.status_code}"
        )

    return args, workers