in src/similarity/inference.py [0:0]
def model_fn(model_dir):
    """Load the similarity CNN from ``model_dir`` and prepare it for inference.

    Two artifacts are read from ``model_dir``:

    * ``model_info.pth`` -- deserialized with ``torch.load`` and indexed by
      ``'simililarity-dims'`` to get the ``similarity_dims`` argument for
      ``CNN``.
    * ``model.pth`` -- deserialized with ``torch.load`` and passed to
      ``CNN.load_state_dict``.

    Args:
        model_dir: Directory containing ``model_info.pth`` and ``model.pth``.

    Returns:
        The ``CNN`` instance with the saved weights loaded, switched to
        eval mode via ``model.eval()``.

    Raises:
        FileNotFoundError: if either artifact is missing from ``model_dir``.
        KeyError: if the loaded model info has no ``'simililarity-dims'``
            entry (assuming it is a dict, as the indexing suggests).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info('Current device: {}'.format(device))

    # map_location remaps any CUDA-tagged storages onto the current device, so
    # artifacts written on a GPU host still deserialize on a CPU-only host
    # (without it torch.load raises when CUDA is unavailable).
    with open(os.path.join(model_dir, 'model_info.pth'), 'rb') as f:
        model_info = torch.load(f, map_location=device)
    print('model_info: {}'.format(model_info))

    # NOTE(review): the key is spelled 'simililarity-dims' (extra "li");
    # presumably it matches what the training script saved -- confirm there
    # before renaming it on either side.
    model = CNN(similarity_dims=model_info['simililarity-dims'])

    with open(os.path.join(model_dir, 'model.pth'), 'rb') as f:
        state_dict = torch.load(f, map_location=device)
    # load_state_dict copies the loaded tensors into the module's existing
    # parameters, so this works regardless of which device they were mapped to.
    model.load_state_dict(state_dict)

    # NOTE(review): the model itself is not moved to `device` (same as before);
    # confirm whether the prediction path expects it on `device`.
    model.eval()
    return model