def _train_impl()

in vision/amazon-sagemaker-pytorch-detectron2/container_training/sku-110k/training.py


def _train_impl(args) -> None:
    r"""Training implementation executes the following steps:

        * Register the dataset splits in the Detectron2 catalog
        * Create the configuration node for training
        * Launch training
        * Serialize the training configuration to a JSON file, which is required at prediction time
    """

    dataset = DataSetMeta(name=args.dataset_name, classes=args.classes)

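    # Fail fast when a required annotation manifest is missing from the channel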
    for ds_type in (
        ("training", "validation", "test")
        if args.evaluation_type
        else ("training", "validation",)
    ):
        if not (Path(args.annotation_channel) / f"{ds_type}.manifest").exists():
            err_msg = f"{ds_type} dataset annotations not found"
            LOGGER.error(err_msg)
            raise FileNotFoundError(err_msg)

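    # Map each dataset split to its image channel and its Ground Truth manifest file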
    channel_to_ds = {
        "training": (
            args.training_channel,
            f"{args.annotation_channel}/training.manifest",
        ),
        "validation": (
            args.validation_channel,
            f"{args.annotation_channel}/validation.manifest",
        ),
    }
    if args.evaluation_type:
        channel_to_ds["test"] = (
            args.test_channel,
            f"{args.annotation_channel}/test.manifest",
        )

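    # Register all splits with the Detectron2 dataset and metadata catalogs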
    register_dataset(
        metadata=dataset, label_name=args.label_name, channel_to_dataset=channel_to_ds,
    )

    cfg = _config_training(args)

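    # Custom config key: keep a value already set by _config_training, else fall back to the CLI log period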
    cfg.setdefault("VAL_LOG_PERIOD", args.log_period)

    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=False)

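    # Training assumes GPU placement; fail early on CPU-only configurations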
    if cfg.MODEL.DEVICE != "cuda":
        err = RuntimeError("A CUDA device is required to launch training")
        LOGGER.error(err)
        raise err
    trainer.train()

    # If in the master process: save config and run COCO evaluation on test set
    if args.current_host == args.hosts[0]:
        with open(f"{cfg.OUTPUT_DIR}/config.json", "w") as fid:
            json.dump(cfg, fid, indent=2)

        if args.evaluation_type:
            LOGGER.info(f"Running {args.evaluation_type} evaluation on the test set")
            evaluator = D2CocoEvaluator(
                dataset_name=f"{dataset.name}_test",
                tasks=("bbox",),
                distributed=len(args.hosts) == 1 and args.num_gpus > 1,
                output_dir=f"{cfg.OUTPUT_DIR}/eval",
                use_fast_impl=args.evaluation_type == "fast",
                nb_max_preds=cfg.TEST.DETECTIONS_PER_IMAGE,
            )
            cfg.DATASETS.TEST = (f"{dataset.name}_test",)
            model = Trainer.build_model(cfg)
            DetectionCheckpointer(model).load(f"{cfg.OUTPUT_DIR}/model_final.pth")
            Trainer.test(cfg, model, evaluator)
        else:
            LOGGER.info("Evaluation on the test set skipped")