in src/sagemaker_defect_detection/detector.py [0:0]
def forward(self, images, *args, **kwargs):
    """Dispatch the forward pass according to the active training stage.

    The stage flags are checked in the order ``train_rpn``, ``train_roi``,
    ``finetune_rpn``, ``finetune_roi``; if none is set, the full ``self.model``
    is run unchanged.

    Args:
        images: Sequence of image tensors. For the standalone RPN / RoI stages
            they are combined with ``torch.stack``, so they must share a shape.
        *args: Ignored.
        **kwargs: Only ``targets`` is read; it is forwarded to the module being
            trained (``None`` when absent).

    Returns:
        ``train_rpn``: the output of ``self.rpn``.
        ``train_roi``: the output of ``self.roi``, fed with proposals from the
        eval-mode ``self.rpn``.
        Otherwise: the output of ``self.model``.
    """
    targets = kwargs.get("targets")

    if self.train_rpn:  # step 2
        image_list, features = self._rpn_inputs(images)
        return self.rpn(image_list, features, targets=targets)

    if self.train_roi:  # step 3
        # The feature network and RPN only supply inputs at this stage. They are
        # switched to eval mode on every call, presumably because the training
        # loop puts the whole module back into train mode — confirm.
        self.mfn.eval()
        self.rpn.eval()
        image_list, features = self._rpn_inputs(images)
        proposals, _ = self.rpn(image_list, features, targets=None)
        # Pass one (H, W) per image; a single entry would truncate the batch.
        return self.roi(features, proposals, image_list.image_sizes, targets=targets)

    # Fine-tuning stages: put the parts that are not being tuned into eval mode,
    # then run the full model exactly like the default path.
    if self.finetune_rpn:
        self.model.backbone.eval()
        self.model.roi_heads.eval()
    elif self.finetune_roi:
        self.model.backbone.eval()
        self.model.rpn.eval()
    return self.model(images, targets=targets)


def _rpn_inputs(self, images):
    """Stack ``images``, run ``self.mfn`` and build the RPN inputs.

    Returns:
        A ``(ImageList, OrderedDict)`` pair: the batched images with one (H, W)
        size per image, and the feature maps keyed by level name.
    """
    batch = torch.stack(images)
    features = self.mfn(batch)
    if isinstance(features, torch.Tensor) and features.dim() == 4:
        # Assumes a 4-D mfn output is a single (B, C, H, W) feature level —
        # TODO confirm. Keep the batch axis intact under level "0"; for B == 1
        # this equals the original per-item unsqueeze.
        feature_maps = OrderedDict([("0", features)])
    else:
        # Any other output type keeps the original per-item keying.
        feature_maps = OrderedDict((str(i), t.unsqueeze(0)) for i, t in enumerate(features))
    height, width = batch.shape[-2:]
    image_sizes = [(int(height), int(width))] * batch.shape[0]
    return ImageList(batch, image_sizes), feature_maps