def main()

in ultravox/inference/run_vllm_inference.py

main() drives a full vLLM evaluation run: it prefetches the model weights, optionally pushes them to the Hugging Face Hub, starts a vLLM server, runs the oaieval suite against it, and logs the resulting metrics to Weights & Biases.


import dataclasses
import os
import subprocess

import simple_parsing
import wandb

# InferenceArgs, prefetch_weights, push_to_hub, start_vllm,
# wait_for_vllm_to_start, and run_oaievalset come from elsewhere in the
# ultravox package / this file (illustrative stand-ins are sketched below).


def main():
    args = simple_parsing.parse(InferenceArgs)

    # Prefetch the model weights so the vLLM server doesn't stall on download.
    model_path = prefetch_weights.download_weights(
        [], args.model, include_models_from_load_dir=True
    )
    args.model = model_path or args.model

    # Optionally upload the weights to the Hugging Face Hub and evaluate
    # the uploaded copy instead of the local one.
    if args.push_to_hub:
        push_to_hub.main(
            push_to_hub.UploadToHubArgs(
                model=args.model,
                hf_upload_model=args.hf_upload_model,
                verify=False,
                device="cpu",
            )
        )
        args.model = args.hf_upload_model

    # Track the evaluation in Weights & Biases.
    run = wandb.init(
        project=os.getenv("WANDB_PROJECT", "ultravox"),
        config=dataclasses.asdict(args),
        name=args.exp_name,
        tags=["eval", "vllm"],
        dir="logs",
    )

    log_dir = os.path.join(run.dir, "oaieval")

    # Launch the vLLM server as a subprocess.
    vllm_process = start_vllm(args)

    try:
        wait_for_vllm_to_start()
        metrics_df = run_oaievalset(args, log_dir)
        # Log the full results table, plus one scalar per eval for easy comparison.
        run.log({"eval": wandb.Table(data=metrics_df)})
        run.log({row["eval"]: row["score"] for _, row in metrics_df.iterrows()})

    finally:
        # Make sure the VLLM server is stopped
        vllm_process.terminate()
        try:
            vllm_process.wait(10)
        except subprocess.TimeoutExpired:
            # Graceful shutdown timed out; force-kill the server.
            vllm_process.kill()
            vllm_process.wait(2)
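
InferenceArgs is defined elsewhere in the file. Judging only from the fields main() reads, a minimal simple_parsing-compatible stand-in might look like the following; the types and defaults here are assumptions, not the repo's actual definition:

import dataclasses
from typing import Optional

@dataclasses.dataclass
class InferenceArgs:
    # Model to evaluate: a local path or a Hugging Face repo id.
    model: str
    # Name for the W&B run.
    exp_name: Optional[str] = None
    # Whether to upload the weights to the Hub before serving.
    push_to_hub: bool = False
    # Target Hub repo when push_to_hub is set.
    hf_upload_model: Optional[str] = None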
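
Similarly, start_vllm, wait_for_vllm_to_start, and run_oaievalset live elsewhere in the file. As a rough sketch of the start-up handshake, wait_for_vllm_to_start could poll the /health endpoint that vLLM's OpenAI-compatible server exposes; the URL, poll interval, and timeout below are illustrative assumptions:

import time
import urllib.error
import urllib.request

def wait_for_vllm_to_start(
    url: str = "http://localhost:8000/health",  # assumed default vLLM port
    timeout_s: float = 600,
) -> None:
    """Block until the vLLM server answers its health check, or raise."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=5) as resp:
                if resp.status == 200:
                    return
        except (urllib.error.URLError, OSError):
            pass  # server not listening yet; keep polling
        time.sleep(5)
    raise TimeoutError(f"vLLM server was not healthy after {timeout_s}s")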