def main()

in src/pixparse/app/eval.py
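Entry point for a standalone evaluation run: parses the eval/data/task configs, sets up the device environment and logging, resolves model weights (from local disk or S3) unless the task evaluates an external model, builds the eval dataloader, and runs the task's evaluation loop.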


def main():
    args = parser.parse_args()
    eval_cfg: EvalCfg = args.eval
    data_cfg: DataCfg = args.data

    # Resolve the (possibly distributed) device environment and build the task from config
    device_env = DeviceEnv()
    task, task_cfg = TaskFactory.create_task(
        task_name=eval_cfg.task_name, task_args=args.task, device_env=device_env, monitor=None,
    )

    random_seed(
        eval_cfg.seed, rank=device_env.global_rank
    )  # Seed variability for eval?
    _logger.info(f"Device env is {device_env}")

    assert (
        eval_cfg.output_dir is not None
    ), "output_dir is not provided. Stopping eval run."

    if device_env.is_primary():
        log_path = os.path.join(eval_cfg.output_dir, eval_cfg.log_filename)

        # Setup text logger
        setup_logging(log_path)

    # Monitor handles experiment/metric output; it is enabled only on the primary rank
    monitor = Monitor(
        eval_cfg.experiment,
        output_dir=eval_cfg.output_dir,
        output_enabled=device_env.is_primary(),
    )
    
    # Check if the current task is an external model evaluation
    # FIXME defer checkpoint loading to the task?

    if eval_cfg.task_name not in ["donut_eval_ocr"]:
        checkpoint_path = eval_cfg.checkpoint_path
        eval_cfg = replace(eval_cfg, checkpoint_path=checkpoint_path)

        # FIXME check if path is local or s3?
        if eval_cfg.s3_bucket != "":
            _logger.info("s3 bucket specified. Loading checkpoint from s3.")
            checkpoint = load_checkpoint_from_s3(
                eval_cfg.s3_bucket, eval_cfg.checkpoint_path
            )
        else:
            assert os.path.isfile(
                checkpoint_path
            ), f"Cannot find checkpoint {checkpoint_path}."

            checkpoint = torch.load(eval_cfg.checkpoint_path)
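        # A bare state_dict loads directly; a full training checkpoint keeps the weights under "model"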
        if isinstance(checkpoint, OrderedDict):
            state_dict = checkpoint
        else:
            state_dict = checkpoint["model"]
        # Derive a filesystem-safe metrics file name from the checkpoint path
        checkpoint_name = eval_cfg.checkpoint_path.replace("/", "_").replace(".pt", "")
        metrics_file_name = f"{checkpoint_name}-{eval_cfg.dataset_name}-metrics.json"

        # Strip any DDP "module." prefix so keys match the unwrapped model
        eval_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        task.resume_state_dict = eval_state_dict
    else:
        # Get a generic name for external model on chosen dataset
        metrics_file_name = f"{eval_cfg.task_name}-{eval_cfg.dataset_name}-metrics.json"

    eval_cfg.metrics_file_path = os.path.join(eval_cfg.output_dir, metrics_file_name)

    if device_env.is_primary():
        _logger.info(task_cfg)
        _logger.info(eval_cfg)

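    # Build the eval dataloader with the task's preprocessing and collation functions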
    loaders = {}
    assert data_cfg.eval is not None, "data_cfg.eval is not set."

    # FIXME add common functionality for loader selection per task
    loaders["eval"] = create_loader(
        data_cfg.eval,
        is_train=False,
        collate_fn=task.collate_fn,
        image_preprocess=task.image_preprocess_eval,
        anno_preprocess=task.anno_preprocess_eval,
        image_fmt=task_cfg.model.image_encoder.image_fmt,
        world_size=device_env.world_size,
        local_rank=device_env.local_rank,
        create_decoder_pipe=create_image_text_pipe,  # TODO abstract away type of decoder needed
    )

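    # Finalize task setup; any resume_state_dict assigned above is expected to be applied here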
    task.setup()

    if device_env.is_primary():
        _logger.info(task)

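    # Run the evaluation loop (this app's eval(), not the Python builtin)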
    eval(
        eval_cfg,
        task,
        loaders,
    )

    task.end()