# vision/amazon-sagemaker-pytorch-detectron2/container_training/sku-110k/training.py
def _train_impl(args) -> None:
    r"""Training implementation that executes the following steps:

    * Register the dataset with the Detectron2 catalog
    * Create the configuration node for training
    * Launch training
    * Serialize the training configuration to a JSON file, as it is required for prediction
    """
    dataset = DataSetMeta(name=args.dataset_name, classes=args.classes)

    # Every required split must come with an annotation manifest; fail early otherwise
    for ds_type in (
        ("training", "validation", "test")
        if args.evaluation_type
        else ("training", "validation")
    ):
        if not (Path(args.annotation_channel) / f"{ds_type}.manifest").exists():
            err_msg = f"{ds_type} dataset annotations not found"
            LOGGER.error(err_msg)
            raise FileNotFoundError(err_msg)
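
    # Map each split to its SageMaker data channel and annotation manifest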
    channel_to_ds = {
        "training": (
            args.training_channel,
            f"{args.annotation_channel}/training.manifest",
        ),
        "validation": (
            args.validation_channel,
            f"{args.annotation_channel}/validation.manifest",
        ),
    }
    if args.evaluation_type:
        channel_to_ds["test"] = (
            args.test_channel,
            f"{args.annotation_channel}/test.manifest",
        )
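
    # Register all splits with the Detectron2 dataset catalog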
    register_dataset(
        metadata=dataset, label_name=args.label_name, channel_to_dataset=channel_to_ds,
    )
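
    # Build the Detectron2 configuration node from the CLI arguments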
    cfg = _config_training(args)
    cfg.setdefault("VAL_LOG_PERIOD", args.log_period)
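
    # Training is GPU-only: fail fast if the config targets a different device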
    if cfg.MODEL.DEVICE != "cuda":
        err_msg = "A CUDA device is required to launch training"
        LOGGER.error(err_msg)
        raise RuntimeError(err_msg)

    # resume=False: start from the weights referenced by the config, not a prior checkpoint
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=False)
    trainer.train()
    # If in the master process: save the config and run COCO evaluation on the test set
    if args.current_host == args.hosts[0]:
        with open(f"{cfg.OUTPUT_DIR}/config.json", "w") as fid:
            json.dump(cfg, fid, indent=2)
        if args.evaluation_type:
            LOGGER.info(f"Running {args.evaluation_type} evaluation on the test set")
            evaluator = D2CocoEvaluator(
                dataset_name=f"{dataset.name}_test",
                tasks=("bbox",),
                distributed=len(args.hosts) == 1 and args.num_gpus > 1,
                output_dir=f"{cfg.OUTPUT_DIR}/eval",
                use_fast_impl=args.evaluation_type == "fast",
                nb_max_preds=cfg.TEST.DETECTIONS_PER_IMAGE,
            )
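
            # Evaluate the final checkpoint on the registered test split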
            cfg.DATASETS.TEST = (f"{dataset.name}_test",)
            model = Trainer.build_model(cfg)
            DetectionCheckpointer(model).load(f"{cfg.OUTPUT_DIR}/model_final.pth")
            Trainer.test(cfg, model, evaluator)
        else:
            LOGGER.info("Evaluation on the test set skipped")