def model_fn()

in source/monai_skin_cancer.py [0:0]


def model_fn(model_dir):
    """Build the DenseNet-121 network and load its saved weights.

    The network is constructed with ``spatial_dims=2``, ``in_channels=3``
    and ``out_channels=7``; its parameters are then read from
    ``model.pth`` inside ``model_dir``.

    Args:
        model_dir: Directory that contains the ``model.pth`` state-dict file.

    Returns:
        The model with the stored weights loaded, moved to CUDA when it is
        available and to the CPU otherwise.

    Raises:
        FileNotFoundError: If ``model.pth`` does not exist in ``model_dir``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = densenet121(
        spatial_dims=2,
        in_channels=3,
        out_channels=7,
    )
    with open(os.path.join(model_dir, 'model.pth'), 'rb') as f:
        # Remap stored tensors onto the selected device. Without
        # map_location, a checkpoint saved from a GPU fails to deserialize
        # on a CPU-only host.
        state_dict = torch.load(f, map_location=device)
    model.load_state_dict(state_dict)
    # NOTE(review): the model is returned in its default (training) mode;
    # confirm the caller switches to eval() before running inference.
    return model.to(device)