in d2go/runner/lightning_task.py [0:0]
def _reset_dataset_evaluators(self):
    """Reset validation dataset evaluators to be run in EVAL_PERIOD steps.

    For every model tag in ``self.dataset_evaluators`` the existing list of
    evaluators is cleared in place and, for each dataset in
    ``cfg.DATASETS.TEST``, a new evaluator is created via
    ``self.get_evaluator``, ``reset()`` and appended. Each evaluator writes to
    ``<OUTPUT_DIR>/inference/<tag>/<next_eval_iter>/<dataset_name>``, and a
    visualization evaluator is attached to it when one is configured.

    Raises:
        AssertionError: if a distributed backend other than DDP / DDP_CPU is
            set on the trainer's accelerator connector, or if
            ``cfg.OUTPUT_DIR`` is empty (the latter is only checked when at
            least one tag exists).
    """
    # NOTE: relies on Lightning's private `_accelerator_connector`; the access
    # is kept inside the assert so it is skipped entirely under `python -O`.
    assert (
        not self.trainer._accelerator_connector.distributed_backend
        or self.trainer._accelerator_connector.distributed_backend.lower()
        in ["ddp", "ddp_cpu"]
    ), ("Only DDP and DDP_CPU distributed backend are supported")

    def _get_inference_dir_name(
        base_dir, inference_type, dataset_name, model_tag: ModelTag
    ):
        """Return the output folder for one (tag, dataset) evaluation."""
        # The folder is named after the step of the upcoming evaluation.
        next_eval_iter = self.trainer.global_step + self.cfg.TEST.EVAL_PERIOD
        if self.trainer.global_step == 0:
            # NOTE(review): at step 0 the step is shifted back by one;
            # presumably to line up with when the first evaluation actually
            # fires — confirm against the evaluation trigger.
            next_eval_iter -= 1
        return os.path.join(
            base_dir,
            inference_type,
            model_tag,
            str(next_eval_iter),
            dataset_name,
        )

    @rank_zero_only
    def _setup_visualization_evaluator(
        evaluator,
        dataset_name: str,
        model_tag: ModelTag,
    ) -> None:
        """Append a visualization evaluator to ``evaluator`` if one is configured.

        Wrapped in ``rank_zero_only``, so it is expected to run only on the
        rank-zero process.
        """
        vis_eval_type = self.get_visualization_evaluator()
        if vis_eval_type is None:
            # Nothing to add: skip building the mapper / tensorboard writer and
            # avoid logging an "Adding ..." message for an evaluator never added.
            return
        logger.info("Adding visualization evaluator ...")
        mapper = self.get_mapper(self.cfg, is_train=False)
        # TODO: replace tbx_writer with Lightning's self.logger.experiment
        tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(self.cfg.OUTPUT_DIR))
        # Assumes `evaluator` is a composite exposing an `_evaluators` list
        # (e.g. detectron2's DatasetEvaluators) — TODO confirm with get_evaluator.
        evaluator._evaluators.append(
            vis_eval_type(
                self.cfg,
                tbx_writer,
                mapper,
                dataset_name,
                train_iter=self.trainer.global_step,
                tag_postfix=model_tag,
            )
        )

    for tag, dataset_evaluators in self.dataset_evaluators.items():
        # Clear in place instead of rebinding; presumably other code holds a
        # reference to these lists — verify before changing to reassignment.
        dataset_evaluators.clear()
        assert self.cfg.OUTPUT_DIR, "Expect output_dir to be specified in config"
        for dataset_name in self.cfg.DATASETS.TEST:
            # setup evaluator for each dataset
            output_folder = _get_inference_dir_name(
                self.cfg.OUTPUT_DIR, "inference", dataset_name, tag
            )
            evaluator = self.get_evaluator(
                self.cfg, dataset_name, output_folder=output_folder
            )
            evaluator.reset()
            dataset_evaluators.append(evaluator)
            _setup_visualization_evaluator(evaluator, dataset_name, tag)