def create_huggingface_api_worker()

in fastchat/serve/huggingface_api_worker.py

Parses command-line arguments, loads per-model settings from the JSON file given by --model-info-file, builds one HuggingfaceApiWorker per model entry, registers all of the models with the controller, and returns (args, workers).

def create_huggingface_api_worker():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    # All model-related parameters are provided via --model-info-file.
    parser.add_argument(
        "--model-info-file",
        type=str,
        required=True,
        help="Path to the JSON file describing the Hugging Face API models.",
    )

    parser.add_argument(
        "--limit-worker-concurrency",
        type=int,
        default=5,
        help="Limit the model concurrency to prevent OOM.",
    )
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Overwrite the random seed for each generation.",
    )
    parser.add_argument(
        "--ssl",
        action="store_true",
        required=False,
        default=False,
        help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
    )
    args = parser.parse_args()

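    # The model info file is a JSON object keyed by model identifier; each entry
    # supplies "model_path", "api_base", "token", and "context_length", plus the
    # optional "model_names" and "conv_template" fields read in the loop below.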
    with open(args.model_info_file, "r", encoding="UTF-8") as f:
        model_info = json.load(f)

    logger.info(f"args: {args}")

    model_path_list = []
    api_base_list = []
    token_list = []
    context_length_list = []
    model_names_list = []
    conv_template_list = []

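    # Collect the per-model settings into parallel lists, one entry per model
    # in the info file.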
    for m in model_info:
        model_path_list.append(model_info[m]["model_path"])
        api_base_list.append(model_info[m]["api_base"])
        token_list.append(model_info[m]["token"])

        context_length = model_info[m]["context_length"]
        model_names = model_info[m].get("model_names", [m.split("/")[-1]])
        if isinstance(model_names, str):
            model_names = [model_names]
        conv_template = model_info[m].get("conv_template", None)

        context_length_list.append(context_length)
        model_names_list.append(model_names)
        conv_template_list.append(conv_template)

    logger.info(f"Model paths: {model_path_list}")
    logger.info(f"API bases: {api_base_list}")
    logger.info(f"Tokens: {token_list}")
    logger.info(f"Context lengths: {context_length_list}")
    logger.info(f"Model names: {model_names_list}")
    logger.info(f"Conv templates: {conv_template_list}")

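    # Create one HuggingfaceApiWorker per model entry. Every worker is created
    # with the same worker address, so they are all registered together below.
    # workers, worker_map, and worker_id are module-level state defined earlier
    # in huggingface_api_worker.py.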
    for (
        model_names,
        conv_template,
        model_path,
        api_base,
        token,
        context_length,
    ) in zip(
        model_names_list,
        conv_template_list,
        model_path_list,
        api_base_list,
        token_list,
        context_length_list,
    ):
        m = HuggingfaceApiWorker(
            args.controller_address,
            args.worker_address,
            worker_id,
            model_path,
            api_base,
            token,
            context_length,
            model_names,
            args.limit_worker_concurrency,
            no_register=args.no_register,
            conv_template=conv_template,
            seed=args.seed,
        )
        workers.append(m)
        for name in model_names:
            worker_map[name] = m

    # Register all models with the controller. Since every worker shares the
    # same worker address, a single registration covers all of them.
    url = args.controller_address + "/register_worker"
    data = {
        "worker_name": workers[0].worker_addr,
        "check_heart_beat": not args.no_register,
        "worker_status": {
            "model_names": [m for w in workers for m in w.model_names],
            "speed": 1,
            "queue_length": sum([w.get_queue_length() for w in workers]),
        },
    }
    r = requests.post(url, json=data)
    assert r.status_code == 200

    return args, workers
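
For context, here is a minimal sketch of what a model info file could look like, using only the fields the loop above reads. The model key, model_path, api_base, and token values are placeholders, not values taken from the source:

{
    "example-model": {
        "model_path": "org/example-model",
        "api_base": "https://api-inference.huggingface.co/models",
        "token": "hf_xxx",
        "context_length": 4096,
        "model_names": ["example-model"],
        "conv_template": null
    }
}

Assuming the module's usual __main__ entry point, the worker would then be started with something like:

python3 -m fastchat.serve.huggingface_api_worker --model-info-file model_info.json

where model_info.json is the illustrative file above.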