in vision/amazon-sagemaker-pytorch-detectron2/container_training/sku-110k/training.py [0:0]
def _config_training(
    args: argparse.Namespace, *, checkpoint_period: int = 20000
) -> CfgNode:
    r"""Create a Detectron2 configuration node from the script arguments.

    In this application we consider the object detection use case only: an
    object detection network pretrained on the COCO dataset (taken from the
    Detectron2 model zoo) is fine-tuned on a custom use case.

    Parameters
    ----------
    args : argparse.Namespace
        training script arguments, see :py:meth:`_parse_args()`
    checkpoint_period : int, optional
        value assigned to ``cfg.SOLVER.CHECKPOINT_PERIOD`` (default 20000,
        the value that used to be hard-coded here)

    Returns
    -------
    CfgNode
        configuration that is used by Detectron2 to train a model

    Raises
    ------
    RuntimeError
        if the combination of `model_type`, `backbone`, `lr_schedule` is not
        valid (re-raised from the model zoo lookup). Please refer to the
        Detectron2 model zoo for valid options.
    NotImplementedError
        if `model_type` is neither ``"faster_rcnn"`` nor ``"retinanet"``
        (a subclass of RuntimeError).

    Notes
    -----
    Side effect: the directory ``args.model_dir`` is created if missing.
    """
    cfg = get_cfg()
    pretrained_model = (
        f"COCO-Detection/{args.model_type}_{args.backbone}_{args.lr_schedule}x.yaml"
    )
    LOGGER.info("Looking for the pretrained model %s...", pretrained_model)
    try:
        cfg.merge_from_file(model_zoo.get_config_file(pretrained_model))
    except RuntimeError as err:
        LOGGER.error("%s: check model backbone and lr schedule combination", err)
        raise
    cfg.DATASETS.TRAIN = (f"{args.dataset_name}_training",)
    cfg.DATASETS.TEST = (f"{args.dataset_name}_validation",)
    cfg.DATALOADER.NUM_WORKERS = args.num_workers
    # Let training initialize from the model zoo checkpoint matching the config
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(pretrained_model)
    LOGGER.info("%s correctly loaded", pretrained_model)

    cfg.SOLVER.CHECKPOINT_PERIOD = checkpoint_period
    cfg.SOLVER.BASE_LR = args.lr
    cfg.SOLVER.MAX_ITER = args.num_iter
    cfg.SOLVER.IMS_PER_BATCH = args.batch_size
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = args.num_rpn

    num_classes = len(args.classes)
    if args.model_type == "faster_rcnn":
        cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.pred_thr
        cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = args.nms_thr
        cfg.MODEL.RPN.BBOX_REG_LOSS_TYPE = args.reg_loss_type
        cfg.MODEL.RPN.BBOX_REG_LOSS_WEIGHT = args.bbox_reg_loss_weight
        cfg.MODEL.RPN.POSITIVE_FRACTION = args.bbox_rpn_pos_fraction
        cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION = args.bbox_head_pos_fraction
    elif args.model_type == "retinanet":
        cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.pred_thr
        cfg.MODEL.RETINANET.NMS_THRESH_TEST = args.nms_thr
        cfg.MODEL.RETINANET.NUM_CLASSES = num_classes
        cfg.MODEL.RETINANET.BBOX_REG_LOSS_TYPE = args.reg_loss_type
        cfg.MODEL.RETINANET.FOCAL_LOSS_GAMMA = args.focal_loss_gamma
        cfg.MODEL.RETINANET.FOCAL_LOSS_ALPHA = args.focal_loss_alpha
    else:
        # An explicit raise, unlike ``assert``, is not stripped under
        # ``python -O``; NotImplementedError subclasses RuntimeError, which
        # keeps this consistent with the documented error contract.
        raise NotImplementedError(f"Add implementation for model {args.model_type}")

    cfg.MODEL.DEVICE = "cuda" if args.num_gpus else "cpu"
    cfg.TEST.DETECTIONS_PER_IMAGE = args.det_per_img
    cfg.OUTPUT_DIR = args.model_dir
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    return cfg