def forward()

in src/sagemaker_defect_detection/detector.py [0:0]


    def forward(self, images, *args, **kwargs):
        """Dispatch the forward pass according to the active stage flag.

        The flags are checked in this order and the first truthy one wins:
        ``train_rpn``, ``train_roi``, ``finetune_rpn``, ``finetune_roi``.
        When none is set, the call is forwarded to ``self.model`` unchanged.

        Args:
            images: Sequence of image tensors. In the ``train_rpn`` and
                ``train_roi`` stages they are combined with ``torch.stack``;
                otherwise they are handed to ``self.model`` as given.
            *args: Accepted but not used.
            **kwargs: Only the optional ``"targets"`` entry is read; it is
                passed on as ``targets`` (``None`` when absent).

        Returns:
            Whatever the selected sub-module returns: ``self.rpn`` for
            ``train_rpn``, ``self.roi`` for ``train_roi``, and ``self.model``
            in every other case.

        NOTE(review): the image size is hard-coded to ``(224, 224)`` and the
        size list has a single entry -- presumably inputs are pre-resized and
        the batch holds one image; confirm against the data pipeline.
        """
        targets = kwargs.get("targets")

        def backbone_inputs():
            # Shared preparation for the staged branches. Order matters:
            # stack, run the MFN, key each feature by its index, then wrap
            # the stacked batch in an ImageList.
            batch = torch.stack(images)
            mfn_out = self.mfn(batch)
            feature_maps = OrderedDict(
                (str(idx), fmap.unsqueeze(0)) for idx, fmap in enumerate(mfn_out)
            )
            return ImageList(batch, [(224, 224)]), feature_maps

        if self.train_rpn:  # step 2
            image_list, feature_maps = backbone_inputs()
            return self.rpn(image_list, feature_maps, targets=targets)

        if self.train_roi:  # step 3
            # MFN and RPN are switched to eval mode before being run here;
            # the RPN is called without targets and only its proposals are
            # kept for the RoI stage.
            self.mfn.eval()
            self.rpn.eval()
            image_list, feature_maps = backbone_inputs()
            proposals, _ = self.rpn(image_list, feature_maps, targets=None)
            return self.roi(feature_maps, proposals, [(224, 224)], targets=targets)

        if self.finetune_rpn:
            # Everything except the RPN is put in eval mode.
            detector = self.model
            detector.backbone.eval()
            detector.roi_heads.eval()
            return detector(images, targets=targets)

        if self.finetune_roi:
            # Everything except the RoI heads is put in eval mode.
            detector = self.model
            detector.backbone.eval()
            detector.rpn.eval()
            return detector(images, targets=targets)

        return self.model(images, targets=targets)