def setup()

in src/sagemaker_defect_detection/detector.py [0:0]


    def setup(self, stage) -> None:
        """Build the sub-networks needed for the currently selected training step.

        Exactly one branch runs, chosen by the first truthy flag among
        ``train_rpn``, ``train_roi``, ``finetune_rpn`` and ``finetune_roi``
        (checked in that order); when none is set the final/joint model is
        assembled.

        * ``train_rpn`` (step 2): ``mfn`` is loaded from ``pretrained_mfn_ckpt``
          (prefix ``"model.mfn"``) and a new ``RPN`` is created.
        * ``train_roi`` (step 3): ``mfn`` and ``rpn`` are loaded from
          ``pretrained_rpn_ckpt`` and passed to ``freeze``; a new ``RoI`` is
          created.
        * ``finetune_rpn`` (step 4 / extra round): ``mfn`` and ``roi`` are
          loaded and passed to ``freeze``; ``rpn`` is loaded and left as is.
        * ``finetune_roi`` (step 5 / extra round): ``mfn`` and ``rpn`` are
          loaded and passed to ``freeze``; ``roi`` is loaded and left as is.
        * otherwise (step 6): all three parts are restored from a single
          checkpoint if one is configured, else left freshly constructed.

        Steps 4-6 additionally set ``self.model`` to
        ``Detection(self.mfn, self.rpn, self.roi)``.

        Args:
            stage: Not used by this method.
                NOTE(review): presumably the stage name handed over by the
                training framework's ``setup`` hook -- confirm.
        """

        def new_mfn():
            # Unloaded ``mfn`` sub-module taken from a Classification model;
            # every branch starts from this and then restores weights into it.
            return Classification(self.backbone, self.num_classes - 1).mfn

        if self.train_rpn:  # step 2
            self.mfn = load_checkpoint(new_mfn(), self.pretrained_mfn_ckpt, "model.mfn")
            self.rpn = RPN()

        elif self.train_roi:  # step 3
            self.mfn = load_checkpoint(new_mfn(), self.pretrained_rpn_ckpt, prefix="mfn")
            freeze(self.mfn)

            self.rpn = load_checkpoint(RPN(), self.pretrained_rpn_ckpt, prefix="rpn")
            freeze(self.rpn)

            self.roi = RoI(self.num_classes)

        elif self.finetune_rpn:  # step 4 or extra finetune rpn
            # An extra round starts from the fine-tuned checkpoints and needs
            # both of them; otherwise fall back to the step 2/3 checkpoints.
            if self.finetuned_rpn_ckpt and self.finetuned_roi_ckpt:
                rpn_ckpt, roi_ckpt = self.finetuned_rpn_ckpt, self.finetuned_roi_ckpt
            else:
                rpn_ckpt, roi_ckpt = self.pretrained_rpn_ckpt, self.pretrained_roi_ckpt

            self.mfn = load_checkpoint(new_mfn(), rpn_ckpt, prefix="mfn")
            freeze(self.mfn)

            # The RPN is the only part not passed to ``freeze`` in this step.
            self.rpn = load_checkpoint(RPN(), rpn_ckpt, prefix="rpn")

            self.roi = load_checkpoint(RoI(self.num_classes), roi_ckpt, prefix="roi")
            freeze(self.roi)

            self.model = Detection(self.mfn, self.rpn, self.roi)

        elif self.finetune_roi:  # step 5 or extra finetune roi
            # mfn and rpn always come from the fine-tuned RPN checkpoint; only
            # the RoI source differs between the first and an extra round.
            if self.finetuned_rpn_ckpt and self.finetuned_roi_ckpt:
                roi_ckpt = self.finetuned_roi_ckpt
            else:
                roi_ckpt = self.pretrained_roi_ckpt

            self.mfn = load_checkpoint(new_mfn(), self.finetuned_rpn_ckpt, prefix="mfn")
            freeze(self.mfn)

            self.rpn = load_checkpoint(RPN(), self.finetuned_rpn_ckpt, prefix="rpn")
            freeze(self.rpn)

            # The RoI head is the only part not passed to ``freeze`` in this step.
            self.roi = load_checkpoint(RoI(self.num_classes), roi_ckpt, prefix="roi")

            self.model = Detection(self.mfn, self.rpn, self.roi)

        else:  # step 6: final/joint model
            # The guard and the path must name the same checkpoint. Previously
            # the guard tested ``finetuned_roi_ckpt`` but ``finetuned_rpn_ckpt``
            # was loaded, which could be None or hold RoI weights from before
            # RoI fine-tuning.
            # NOTE(review): assumes the RoI fine-tuning checkpoint stores the
            # "mfn", "rpn" and "roi" prefixes -- confirm against a saved file.
            if self.finetuned_roi_ckpt is not None:
                ckpt_path = self.finetuned_roi_ckpt
            elif self.resume_sagemaker_from_checkpoint is not None:
                ckpt_path = self.resume_sagemaker_from_checkpoint
            else:
                ckpt_path = None

            def restore(module, prefix):
                # With no checkpoint configured, keep the freshly built module.
                if ckpt_path is None:
                    return module
                return load_checkpoint(module, ckpt_path, prefix)

            self.mfn = restore(new_mfn(), "mfn")
            self.rpn = restore(RPN(), "rpn")
            self.roi = restore(RoI(self.num_classes), "roi")
            self.model = Detection(self.mfn, self.rpn, self.roi)