def _config_training()

in vision/amazon-sagemaker-pytorch-detectron2/container_training/sku-110k/training.py [0:0]


def _config_training(args: argparse.Namespace) -> CfgNode:
    r"""Create a configuration node from the script arguments.

    In this application we consider object detection use case only. We finetune object detection
    networks trained on COCO dataset to a custom use case.

    The base configuration and the initial weights are both resolved from the Detectron2 model
    zoo entry ``COCO-Detection/{model_type}_{backbone}_{lr_schedule}x.yaml``. As a side effect,
    the output directory ``args.model_dir`` is created if it does not exist yet.

    Parameters
    ----------
    args : argparse.Namespace
        training script arguments, see :py:meth:`_parse_args()`

    Returns
    -------
    CfgNode
        configuration that is used by Detectron2 to train a model

    Raises
    ------
    RuntimeError
        if the combination of `model_type`, `backbone`, `lr_schedule` is not valid.
        Please refer to Detectron2 model zoo for valid options.
    NotImplementedError
        if `model_type` is neither ``"faster_rcnn"`` nor ``"retinanet"``. This is a subclass of
        :py:class:`RuntimeError`, so callers handling ``RuntimeError`` also handle this case.
    """
    cfg = get_cfg()
    pretrained_model = (
        f"COCO-Detection/{args.model_type}_{args.backbone}_{args.lr_schedule}x.yaml"
    )
    LOGGER.info(f"Loooking for the pretrained model {pretrained_model}...")
    try:
        cfg.merge_from_file(model_zoo.get_config_file(pretrained_model))
    except RuntimeError as err:
        # Log a hint for the user, then propagate the original error unchanged.
        LOGGER.error(f"{err}: check model backbone and lr schedule combination")
        raise
    cfg.DATASETS.TRAIN = (f"{args.dataset_name}_training",)
    cfg.DATASETS.TEST = (f"{args.dataset_name}_validation",)
    cfg.DATALOADER.NUM_WORKERS = args.num_workers
    # Let training initialize from model zoo
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(pretrained_model)
    LOGGER.info(f"{pretrained_model} correctly loaded")

    # Solver settings shared by every supported model type.
    cfg.SOLVER.CHECKPOINT_PERIOD = 20000
    cfg.SOLVER.BASE_LR = args.lr
    cfg.SOLVER.MAX_ITER = args.num_iter
    cfg.SOLVER.IMS_PER_BATCH = args.batch_size
    # NOTE: set regardless of the model type selected below.
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = args.num_rpn

    # Model-specific head configuration.
    if args.model_type == "faster_rcnn":
        cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(args.classes)
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.pred_thr
        cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = args.nms_thr
        cfg.MODEL.RPN.BBOX_REG_LOSS_TYPE = args.reg_loss_type
        cfg.MODEL.RPN.BBOX_REG_LOSS_WEIGHT = args.bbox_reg_loss_weight
        cfg.MODEL.RPN.POSITIVE_FRACTION = args.bbox_rpn_pos_fraction
        cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION = args.bbox_head_pos_fraction
    elif args.model_type == "retinanet":
        cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.pred_thr
        cfg.MODEL.RETINANET.NMS_THRESH_TEST = args.nms_thr
        cfg.MODEL.RETINANET.NUM_CLASSES = len(args.classes)
        cfg.MODEL.RETINANET.BBOX_REG_LOSS_TYPE = args.reg_loss_type
        cfg.MODEL.RETINANET.FOCAL_LOSS_GAMMA = args.focal_loss_gamma
        cfg.MODEL.RETINANET.FOCAL_LOSS_ALPHA = args.focal_loss_alpha
    else:
        # An explicit raise instead of `assert False`: assertions are stripped when Python runs
        # with -O, which would let an unsupported model type slip through with a
        # half-configured node.
        raise NotImplementedError(f"Add implementation for model {args.model_type}")
    cfg.MODEL.DEVICE = "cuda" if args.num_gpus else "cpu"

    cfg.TEST.DETECTIONS_PER_IMAGE = args.det_per_img

    cfg.OUTPUT_DIR = args.model_dir
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    return cfg