in MMS/dicom_featurization_service.py [0:0]
def initialize(self, context):
    """Load the featurization model and mark this service as initialized.

    If a ``model.pth`` file exists in the model directory it is loaded with
    ``torch.load``. Otherwise a pretrained torchvision DenseNet-121 is
    created, its ``features`` sub-module is kept, and a ReLU, an adaptive
    average pool to 1x1 and a flatten layer are appended to it.

    Args:
        context: Serving context. Only ``context.system_properties`` is
            read, and from it the ``"model_dir"`` entry.

    Side effects:
        Stores the model (moved to ``self.device`` and switched to eval
        mode) in ``self.model`` and sets ``self.initialized`` to ``True``.
    """
    properties = context.system_properties
    model_dir = properties.get("model_dir")

    # Use the model file if one was provided with the deployment.
    model_file_path = os.path.join(model_dir, "model.pth")
    if os.path.isfile(model_file_path):
        # map_location remaps stored tensors onto the serving device while
        # deserializing, so a checkpoint saved on a GPU can still be loaded
        # on a CPU-only host.
        # NOTE(review): torch.load unpickles the file, so model.pth must come
        # from a trusted source. Recent PyTorch releases default to
        # weights_only=True, which rejects fully pickled modules -- confirm
        # the checkpoint format against the deployed PyTorch version.
        model = torch.load(model_file_path, map_location=self.device)
    else:
        # No model file: fall back to the pretrained DenseNet-121 feature
        # stack and append the layers that follow it.
        backbone = models.densenet121(pretrained=True)
        model = backbone.features
        model.add_module("last_relu", nn.ReLU())
        model.add_module("glob_pool", nn.AdaptiveAvgPool2d((1, 1)))
        model.add_module("last_flatten", nn.Flatten())

    model = model.to(self.device)
    model.eval()  # inference mode for dropout / batch-norm layers
    self.model = model
    self.initialized = True