def main()

in ultravox/evaluation/eval.py [0:0]


def main(override_sys_args: Optional[List[str]] = None):
    """Parse an ``EvalConfig``, evaluate the configured model on its eval sets, and report.

    Works for single-process runs and for multi-process runs (``world_size > 1``).
    In the multi-process case a gloo process group is created up front and is
    always destroyed before returning, including when setup or evaluation raises.

    Args:
        override_sys_args: Argument list passed to ``simple_parsing.parse`` as
            ``args``. ``None`` leaves argument selection to simple_parsing
            (normally the process's command line).
    """
    monkey_patches.apply_all_patches()

    config = simple_parsing.parse(
        EvalConfig, add_config_path_arg=True, args=override_sys_args
    )

    world_size = device_helpers.get_world_size()
    local_rank = device_helpers.get_local_rank()
    is_distributed = world_size > 1
    # NOTE(review): output-dir creation, wandb and result printing are gated on
    # the *local* rank, so in a multi-node run they happen once per node —
    # confirm that is intended.
    is_local_master = local_rank == 0

    if is_distributed:
        # Use gloo rather than nccl because the gather performed during
        # evaluation operates on CPU; NCCL support for this path is unverified.
        dist.init_process_group(backend="gloo")

    try:
        if is_local_master:
            if config.output_dir:
                config.output_dir.mkdir(parents=True, exist_ok=True)
            if "wandb" in config.report_logs_to:
                wandb.init(
                    project=os.getenv("WANDB_PROJECT", "ultravox"),
                    config=dataclasses.asdict(config),
                    name=config.exp_name,
                    dir="runs",
                    save_code=True,
                )

        # Let the local master load the model first; the other ranks follow.
        with ddp_utils.run_on_master_first(is_local_master):
            inference = ultravox_infer.UltravoxInference(
                config.model,
                # One device index per local rank when running distributed.
                device=(
                    f"{config.device}:{local_rank}"
                    if is_distributed
                    else config.device
                ),
                data_type=config.data_type,
            )

        metrics, output_files = eval_datasets(
            inference,
            config.get_eval_sets(),
            config.eval_dataset_args,
            config.eval_batch_size,
            config.eval_max_tokens,
            config.eval_temperature,
            config.output_dir,
        )

        if is_local_master:
            print_results(metrics, output_files)

            # Only finished on the success path; if an exception propagates,
            # the run is not explicitly finished here.
            if wandb.run:
                wandb.run.finish()
    finally:
        # Always release the process group, even if setup or evaluation
        # raised, so distributed resources are not leaked on failure.
        if is_distributed:
            dist.destroy_process_group()