in src/sagemaker_defect_detection/detector.py [0:0]
def setup(self, stage) -> None:
    """Build the sub-networks needed for the currently selected training step.

    Exactly one branch runs, chosen by the first truthy flag among
    ``train_rpn``, ``train_roi``, ``finetune_rpn`` and ``finetune_roi``; when
    none of them is set, the joint model (step 6) is built.

    Attributes assigned on ``self``:
        * ``mfn`` and ``rpn`` -- in every branch.
        * ``roi`` -- in every branch except step 2.
        * ``model`` -- from step 4 onwards (not in steps 2 and 3).

    Args:
        stage: Not used by this method.
    """

    def new_mfn():
        # Fresh ``mfn`` sub-module; the classifier is built with one class fewer than ``num_classes``.
        return Classification(self.backbone, self.num_classes - 1).mfn

    if self.train_rpn:  # step 2
        self.mfn = load_checkpoint(new_mfn(), self.pretrained_mfn_ckpt, "model.mfn")
        self.rpn = RPN()
    elif self.train_roi:  # step 3
        self.mfn = load_checkpoint(new_mfn(), self.pretrained_rpn_ckpt, prefix="mfn")
        freeze(self.mfn)
        self.rpn = load_checkpoint(RPN(), self.pretrained_rpn_ckpt, prefix="rpn")
        freeze(self.rpn)
        self.roi = RoI(self.num_classes)
    elif self.finetune_rpn:  # step 4 or extra finetune rpn
        # The two variants differ only in where the weights come from.
        if self.finetuned_rpn_ckpt and self.finetuned_roi_ckpt:  # extra finetune rpn
            rpn_ckpt, roi_ckpt = self.finetuned_rpn_ckpt, self.finetuned_roi_ckpt
        else:
            rpn_ckpt, roi_ckpt = self.pretrained_rpn_ckpt, self.pretrained_roi_ckpt
        self.mfn = load_checkpoint(new_mfn(), rpn_ckpt, prefix="mfn")
        freeze(self.mfn)
        # The rpn is the only sub-network left unfrozen in this step.
        self.rpn = load_checkpoint(RPN(), rpn_ckpt, prefix="rpn")
        self.roi = load_checkpoint(RoI(self.num_classes), roi_ckpt, prefix="roi")
        freeze(self.roi)
        self.model = Detection(self.mfn, self.rpn, self.roi)
    elif self.finetune_roi:  # step 5 or extra finetune roi
        # mfn and rpn always come from the fine-tuned rpn checkpoint; only the roi source varies.
        if self.finetuned_rpn_ckpt and self.finetuned_roi_ckpt:  # extra finetune roi
            roi_ckpt = self.finetuned_roi_ckpt
        else:
            roi_ckpt = self.pretrained_roi_ckpt
        self.mfn = load_checkpoint(new_mfn(), self.finetuned_rpn_ckpt, prefix="mfn")
        freeze(self.mfn)
        self.rpn = load_checkpoint(RPN(), self.finetuned_rpn_ckpt, prefix="rpn")
        freeze(self.rpn)
        # The roi is the only sub-network left unfrozen in this step.
        self.roi = load_checkpoint(RoI(self.num_classes), roi_ckpt, prefix="roi")
        self.model = Detection(self.mfn, self.rpn, self.roi)
    else:  # step 6: final/joint model
        # Require BOTH fine-tuned checkpoints before using the rpn one. Testing only
        # ``finetuned_roi_ckpt`` (as before) let a None ``finetuned_rpn_ckpt`` be passed
        # to ``load_checkpoint`` as the path.
        # NOTE(review): all three sub-networks, roi included, are restored from
        # ``finetuned_rpn_ckpt``; ``finetuned_roi_ckpt`` only gates the choice -- confirm intended.
        if self.finetuned_roi_ckpt is not None and self.finetuned_rpn_ckpt is not None:
            ckpt_path = self.finetuned_rpn_ckpt
        elif self.resume_sagemaker_from_checkpoint is not None:
            ckpt_path = self.resume_sagemaker_from_checkpoint
        else:
            ckpt_path = None

        def restore(module, prefix):
            # Without a checkpoint the freshly constructed module is used as-is.
            if ckpt_path is None:
                return module
            return load_checkpoint(module, ckpt_path, prefix)

        self.mfn = restore(new_mfn(), "mfn")
        self.rpn = restore(RPN(), "rpn")
        self.roi = restore(RoI(self.num_classes), "roi")
        self.model = Detection(self.mfn, self.rpn, self.roi)