in src/pixparse/app/eval.py [0:0]
def main():
    """Evaluate a task on the configured eval dataset.

    Parses CLI args into ``EvalCfg`` / ``DataCfg``, creates the task, loads
    checkpoint weights into ``task.resume_state_dict`` (skipped for tasks that
    evaluate an external model), builds the eval loader and calls ``eval``.
    Sets ``eval_cfg.metrics_file_path`` under ``eval_cfg.output_dir``
    (presumably where ``eval`` writes its metrics — confirm).

    Raises:
        ValueError: if ``output_dir`` or the eval data config is not set.
        FileNotFoundError: if a local checkpoint path does not exist.
    """
    args = parser.parse_args()
    eval_cfg: EvalCfg = args.eval
    data_cfg: DataCfg = args.data

    # Validate required config before any expensive setup. Explicit raises
    # (not asserts) so the checks are not stripped under `python -O`.
    if eval_cfg.output_dir is None:
        raise ValueError("output_dir is not provided. Stopping eval run.")
    if data_cfg.eval is None:
        raise ValueError("data_cfg.eval is not set.")

    device_env = DeviceEnv()
    task, task_cfg = TaskFactory.create_task(
        task_name=eval_cfg.task_name,
        task_args=args.task,
        device_env=device_env,
        monitor=None,
    )
    random_seed(eval_cfg.seed, rank=device_env.global_rank)  # Seed variability for eval?
    _logger.info(f"Device env is {device_env}")

    # log_path must be bound on every rank: only the primary rank gets a log
    # file, the others pass None.
    # NOTE(review): assumes setup_logging skips the file handler for a falsy
    # path — confirm.
    log_path = None
    if device_env.is_primary():
        os.makedirs(eval_cfg.output_dir, exist_ok=True)
        log_path = os.path.join(eval_cfg.output_dir, eval_cfg.log_filename)
    # Setup text logger
    setup_logging(log_path)

    # NOTE(review): `monitor` is not used below (the task was created with
    # monitor=None); it is still constructed in case construction has side
    # effects — confirm whether it should be wired into the task.
    monitor = Monitor(
        eval_cfg.experiment,
        output_dir=eval_cfg.output_dir,
        output_enabled=device_env.is_primary(),
    )

    # Tasks in this tuple evaluate an external model and have no checkpoint
    # of ours to load.
    # FIXME defer load checkpoint to task?
    if eval_cfg.task_name not in ("donut_eval_ocr",):
        checkpoint_path = eval_cfg.checkpoint_path
        # Effectively a copy of eval_cfg with the same checkpoint_path.
        eval_cfg = replace(eval_cfg, checkpoint_path=checkpoint_path)
        # FIXME check if path is local or s3?
        if eval_cfg.s3_bucket:
            _logger.info("s3 bucket specified. Loading checkpoint from s3.")
            checkpoint = load_checkpoint_from_s3(eval_cfg.s3_bucket, checkpoint_path)
        else:
            if not os.path.isfile(checkpoint_path):
                raise FileNotFoundError(
                    f"Cannot find checkpoint {checkpoint_path}: File not found"
                )
            # Deserialize onto CPU so ranks don't all allocate on the device the
            # tensors were saved from (and CPU-only hosts can load GPU
            # checkpoints). Presumably the task copies these onto its model's
            # device via load_state_dict — confirm.
            checkpoint = torch.load(checkpoint_path, map_location="cpu")

        # A checkpoint is either a bare state dict or a dict holding it under "model".
        if isinstance(checkpoint, OrderedDict):
            state_dict = checkpoint
        else:
            state_dict = checkpoint["model"]

        # Create safe metrics file name: flatten the path and drop only a
        # trailing .pt/.pth extension (not every ".pt" substring).
        checkpoint_name = checkpoint_path.replace("/", "_")
        for ext in (".pth", ".pt"):
            if checkpoint_name.endswith(ext):
                checkpoint_name = checkpoint_name[: -len(ext)]
                break
        metrics_file_name = f"{checkpoint_name}-{eval_cfg.dataset_name}-metrics.json"

        # bypass DDP module wrapper in parameter names
        task.resume_state_dict = {
            _strip_ddp_module(k): v for k, v in state_dict.items()
        }
    else:
        # Get a generic name for external model on chosen dataset
        metrics_file_name = f"{eval_cfg.task_name}-{eval_cfg.dataset_name}-metrics.json"

    eval_cfg.metrics_file_path = os.path.join(eval_cfg.output_dir, metrics_file_name)
    if device_env.is_primary():
        _logger.info(task_cfg)
        _logger.info(eval_cfg)

    # FIXME add common functionality for loader selection per task
    loaders = {
        "eval": create_loader(
            data_cfg.eval,
            is_train=False,
            collate_fn=task.collate_fn,
            image_preprocess=task.image_preprocess_eval,
            anno_preprocess=task.anno_preprocess_eval,
            image_fmt=task_cfg.model.image_encoder.image_fmt,
            world_size=device_env.world_size,
            local_rank=device_env.local_rank,
            create_decoder_pipe=create_image_text_pipe,  # TODO abstract away type of decoder needed
        ),
    }

    task.setup()
    if device_env.is_primary():
        _logger.info(task)
    eval(eval_cfg, task, loaders)
    task.end()


def _strip_ddp_module(key: str) -> str:
    """Remove wrapper-inserted ``module`` path components from a state-dict key.

    Unlike ``key.replace("module.", "")``, this drops only whole dot-separated
    components equal to ``module``, so keys such as ``submodule.weight`` are
    left intact. The final component (the tensor's own name) is never dropped,
    e.g. ``module.encoder.weight`` -> ``encoder.weight``.
    """
    parts = key.split(".")
    return ".".join([p for p in parts[:-1] if p != "module"] + parts[-1:])