def entrypoint()

in optimum_benchmark/launchers/torchrun/launcher.py [0:0]


def entrypoint(worker: Callable[..., BenchmarkReport], worker_args: List[Any], logger: Logger):
    """Run ``worker`` inside one torchrun rank process and return its serialized result.

    Configuration is read from environment variables:

    * ``RANK`` (default ``"0"``): global rank, used for the logging prefix and to
      decide which rank logs at ``LOG_LEVEL``.
    * ``LOCAL_RANK`` (default: the value of ``RANK``): index of the CUDA device
      this process binds to when CUDA is available.
    * ``LOG_LEVEL`` (default ``"INFO"``): log level for rank 0, or for every rank
      when ``LOG_ALL_RANKS`` is ``"1"``. Other ranks log at ``"ERROR"``.
    * ``LOG_TO_FILE`` (default ``"1"``): passed to ``setup_logging`` as ``to_file``.
    * ``LOG_ALL_RANKS`` (default ``"0"``): when ``"1"``, every rank logs at ``LOG_LEVEL``.

    Args:
        worker: Callable executed as ``worker(*worker_args)``; its return value must
            provide ``to_dict()``.
        worker_args: Positional arguments forwarded to ``worker``.
        logger: Logger used for progress messages.

    Returns:
        ``report.to_dict()`` on success. If the worker, or the serialization of its
        report, raises an ``Exception``, the formatted traceback string is returned
        instead.

    The process group is initialized before calling ``worker`` and always destroyed
    afterwards. Non-``Exception`` errors (e.g. ``KeyboardInterrupt``) propagate to
    the caller after cleanup.
    """
    rank = int(os.environ.get("RANK", "0"))
    # The device index must be node-local. In multi-node jobs the global RANK can
    # exceed the number of GPUs on this node, so prefer LOCAL_RANK and fall back
    # to RANK (identical on single-node runs).
    local_rank = int(os.environ.get("LOCAL_RANK", str(rank)))
    log_level = os.environ.get("LOG_LEVEL", "INFO")
    log_to_file = os.environ.get("LOG_TO_FILE", "1") == "1"
    log_all_ranks = os.environ.get("LOG_ALL_RANKS", "0") == "1"

    # Only rank 0 logs verbosely unless all ranks were explicitly requested.
    if log_all_ranks or rank == 0:
        setup_logging(level=log_level, to_file=log_to_file, prefix=f"RANK-PROCESS-{rank}")
    else:
        setup_logging(level="ERROR", to_file=log_to_file, prefix=f"RANK-PROCESS-{rank}")

    if torch.cuda.is_available():
        logger.info(f"\t+ Setting torch.distributed cuda device to {local_rank}")
        device = torch.device("cuda", local_rank)
        torch.cuda.set_device(device)

    # backend=None lets torch.distributed pick its default; on MPS hosts gloo is
    # forced. Older torch builds may lack torch.mps.is_available, hence the hasattr.
    backend = None
    if hasattr(torch.mps, "is_available") and torch.mps.is_available():
        backend = "gloo"

    logger.info("\t+ Initializing torch.distributed process group")
    torch.distributed.init_process_group(backend=backend)

    try:
        report = worker(*worker_args)
        # Serialize inside the try so a failing to_dict() is reported as a
        # traceback instead of escaping with `output` left unbound.
        output = report.to_dict()
    except Exception:
        logger.error("\t+ Benchmark failed with an exception")
        output = traceback.format_exc()
    else:
        logger.info("\t+ Benchmark completed successfully")
    finally:
        # Always tear down the process group, even if a BaseException escapes.
        logger.info("\t+ Destroying torch.distributed process group")
        torch.distributed.destroy_process_group()

    # Return outside `finally` so exceptions not handled above (e.g.
    # KeyboardInterrupt) are not swallowed or masked by an UnboundLocalError.
    logger.info("\t+ Exiting rank process")
    return output