in scripts/train_instance_seg.py [0:0]
def make_model(config, num_thing, num_stuff):
    """Build the instance segmentation network described by ``config``.

    The network is assembled from four configuration sections -- ``"body"``,
    ``"fpn"``, ``"rpn"`` and ``"roi"`` -- in this order: backbone (optionally
    initialised from a weights file and partially frozen), FPN, RPN head and
    algorithm, ROI (detection + mask) head and algorithm.

    Args:
        config: Mapping from section name to a section object. Each section is
            read with ``[]``/``get``/``getint``/``getfloat``/``getboolean``/
            ``getstruct`` (presumably a configparser-like object -- confirm
            against the caller).
        num_thing: Number of "thing" classes.
        num_stuff: Number of "stuff" classes.

    Returns:
        The ``InstanceSegNet`` built from the FPN-wrapped backbone, the RPN and
        ROI heads, their algorithms and the class-count dictionary.
    """
    body_config = config["body"]
    fpn_config = config["fpn"]
    rpn_config = config["rpn"]
    roi_config = config["roi"]

    classes = {"total": num_thing + num_stuff, "stuff": num_stuff, "thing": num_thing}

    # BN + activation
    norm_act_static, norm_act_dynamic = norm_act_from_config(body_config)

    # Create backbone
    log_debug("Creating backbone model %s", body_config["body"])
    body_fn = models.__dict__["net_" + body_config["body"]]
    body_params = body_config.getstruct("body_params") if body_config.get("body_params") else {}
    body = body_fn(norm_act=norm_act_static, **body_params)
    if body_config.get("weights"):
        # NOTE: torch.load unpickles the checkpoint, so the configured weights
        # file must come from a trusted source.
        body.load_state_dict(torch.load(body_config["weights"], map_location="cpu"))

    # Freeze parameters. The number of frozen modules and the name tags are
    # loop-invariant, so they are parsed/formatted once instead of once per
    # sub-module.
    num_frozen = body_config.getint("num_frozen")
    frozen_tags = ["mod%d" % mod_id for mod_id in range(1, num_frozen + 1)]
    for name, module in body.named_modules():
        for tag in frozen_tags:
            # Substring match: the tag "mod1" also matches a module named
            # "mod10". NOTE(review): confirm no backbone has >= 10 "modN" stages.
            if tag in name:
                freeze_params(module)

    body_channels = body_config.getstruct("out_channels")

    # Create FPN
    fpn_inputs = fpn_config.getstruct("inputs")
    fpn_in_channels = [body_channels[inp] for inp in fpn_inputs]
    # Shared by the FPN, the RPN head and the ROI head; parsed once.
    fpn_out_channels = fpn_config.getint("out_channels")
    fpn = FPN(
        fpn_in_channels,
        fpn_out_channels,
        fpn_config.getint("extra_scales"),
        norm_act_static,
        fpn_config["interpolation"],
    )
    body = FPNBody(body, fpn, fpn_inputs)

    # Create RPN
    proposal_generator = ProposalGenerator(
        rpn_config.getfloat("nms_threshold"),
        rpn_config.getint("num_pre_nms_train"),
        rpn_config.getint("num_post_nms_train"),
        rpn_config.getint("num_pre_nms_val"),
        rpn_config.getint("num_post_nms_val"),
        rpn_config.getint("min_size"),
    )
    anchor_matcher = AnchorMatcher(
        rpn_config.getint("num_samples"),
        rpn_config.getfloat("pos_ratio"),
        rpn_config.getfloat("pos_threshold"),
        rpn_config.getfloat("neg_threshold"),
        rpn_config.getfloat("void_threshold"),
    )
    rpn_loss = RPNLoss(rpn_config.getfloat("sigma"))
    anchor_scale = rpn_config.getint("anchor_scale")
    # Used both by the RPN algorithm and to size the RPN head; parsed once.
    anchor_ratios = rpn_config.getstruct("anchor_ratios")
    rpn_algo = RPNAlgoFPN(
        proposal_generator,
        anchor_matcher,
        rpn_loss,
        anchor_scale,
        anchor_ratios,
        fpn_config.getstruct("out_strides"),
        rpn_config.getint("fpn_min_level"),
        rpn_config.getint("fpn_levels"),
    )
    rpn_head = RPNHead(
        fpn_out_channels,
        len(anchor_ratios),
        1,
        rpn_config.getint("hidden_channels"),
        norm_act_dynamic,
    )

    # Create instance segmentation network
    bbx_prediction_generator = BbxPredictionGenerator(
        roi_config.getfloat("nms_threshold"),
        roi_config.getfloat("score_threshold"),
        roi_config.getint("max_predictions"),
    )
    msk_prediction_generator = MskPredictionGenerator()
    roi_size = roi_config.getstruct("roi_size")
    proposal_matcher = ProposalMatcher(
        classes,
        roi_config.getint("num_samples"),
        roi_config.getfloat("pos_ratio"),
        roi_config.getfloat("pos_threshold"),
        roi_config.getfloat("neg_threshold_hi"),
        roi_config.getfloat("neg_threshold_lo"),
        roi_config.getfloat("void_threshold"),
    )
    bbx_loss = DetectionLoss(roi_config.getfloat("sigma"))
    msk_loss = InstanceSegLoss()
    # Second size passed to the ROI algorithm: twice the ROI size per dimension.
    lbl_roi_size = tuple(s * 2 for s in roi_size)
    roi_algo = InstanceSegAlgoFPN(
        bbx_prediction_generator,
        msk_prediction_generator,
        proposal_matcher,
        bbx_loss,
        msk_loss,
        classes,
        roi_config.getstruct("bbx_reg_weights"),
        roi_config.getint("fpn_canonical_scale"),
        roi_config.getint("fpn_canonical_level"),
        roi_size,
        roi_config.getint("fpn_min_level"),
        roi_config.getint("fpn_levels"),
        lbl_roi_size,
        roi_config.getboolean("void_is_background"),
    )
    roi_head = FPNMaskHead(fpn_out_channels, classes, roi_size, norm_act=norm_act_dynamic)

    # Create final network
    return InstanceSegNet(body, rpn_head, roi_head, rpn_algo, roi_algo, classes)