in code/pretrained_model.py [0:0]
def model_fn(model_dir):
    """Load the fine-tuned ResNet-18 traffic-sign classifier for inference.

    Builds a ResNet-18 architecture, swaps its final fully connected layer
    for a 43-way classifier, loads the saved state dict from
    ``<model_dir>/model/model.pt``, and returns the model in evaluation
    mode on the CPU.

    The signature intentionally takes only ``model_dir``.
    NOTE(review): hosting toolkits such as the SageMaker inference toolkit
    inspect this hook's arity and may pass an extra argument positionally,
    so do not add parameters (even defaulted ones) without confirming.

    Args:
        model_dir: Directory containing the model artifacts. The weights
            are read from ``model/model.pt`` beneath it.

    Returns:
        The model with the saved weights loaded, switched to eval mode and
        moved to the CPU.
    """
    import os

    # The traffic sign dataset has 43 classes.
    num_classes = 43

    # Default-constructed ResNet-18. load_state_dict below (strict by
    # default) overwrites every parameter with the fine-tuned values.
    model = models.resnet18()
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    weights_path = os.path.join(model_dir, "model", "model.pt")
    # map_location="cpu" remaps tensors that were saved on a GPU onto the
    # CPU, so the artifact loads on CPU-only hosts. This is the documented
    # equivalent of the old identity-lambda form.
    # NOTE(review): torch.load unpickles the file; only load trusted
    # artifacts. Consider weights_only=True once the deployed torch version
    # is confirmed to support it.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)

    model.eval()
    model.cpu()
    return model