in ultravox/evaluation/eval.py [0:0]
def main(override_sys_args: Optional[List[str]] = None):
    """Evaluate an Ultravox model on the eval sets described by an ``EvalConfig``.

    Steps:
      1. Apply monkey patches and parse an ``EvalConfig`` (with a config-path
         argument enabled). ``override_sys_args`` is passed to the parser as
         ``args``. Presumably ``None`` means the process command line is parsed
         (argparse default) — confirm against simple_parsing.
      2. When more than one process is running, initialize a ``gloo`` process
         group.
      3. On local rank 0, create ``config.output_dir`` (if set) and start a
         wandb run when ``"wandb"`` is in ``config.report_logs_to``.
      4. Build ``UltravoxInference`` inside ``run_on_master_first``, then run
         ``eval_datasets`` over ``config.get_eval_sets()``.
      5. On local rank 0, print the results and finish the wandb run.
      6. Destroy the process group. This also happens if an earlier step
         raised.

    Args:
        override_sys_args: Optional explicit argument list to parse instead of
            the process's own command-line arguments.
    """
    monkey_patches.apply_all_patches()
    config = simple_parsing.parse(
        EvalConfig, add_config_path_arg=True, args=override_sys_args
    )

    world_size = device_helpers.get_world_size()
    local_rank = device_helpers.get_local_rank()
    is_distributed = world_size > 1
    is_main_process = local_rank == 0

    if is_distributed:
        # Use gloo instead of nccl because the gathering operation runs on CPU.
        # TODO: double-check whether nccl was supported here previously.
        dist.init_process_group(backend="gloo")

    # Everything after group creation is wrapped so that the process group is
    # torn down even when model loading or evaluation raises.
    try:
        if is_main_process:
            if config.output_dir:
                config.output_dir.mkdir(parents=True, exist_ok=True)
            if "wandb" in config.report_logs_to:
                wandb.init(
                    project=os.getenv("WANDB_PROJECT", "ultravox"),
                    config=dataclasses.asdict(config),
                    name=config.exp_name,
                    dir="runs",
                    save_code=True,
                )

        # Rank 0 builds the model before the other ranks. Presumably this is so
        # any model download/caching happens once (verify in ddp_utils). When
        # distributed, each process uses the device index equal to its local
        # rank.
        with ddp_utils.run_on_master_first(is_main_process):
            device = (
                f"{config.device}:{local_rank}" if is_distributed else config.device
            )
            inference = ultravox_infer.UltravoxInference(
                config.model,
                device=device,
                data_type=config.data_type,
            )

        metrics, output_files = eval_datasets(
            inference,
            config.get_eval_sets(),
            config.eval_dataset_args,
            config.eval_batch_size,
            config.eval_max_tokens,
            config.eval_temperature,
            config.output_dir,
        )

        if is_main_process:
            print_results(metrics, output_files)
            # wandb.run is only set if wandb.init was called above.
            if wandb.run:
                wandb.run.finish()
    finally:
        if is_distributed:
            dist.destroy_process_group()