in denoiser/pretrained.py [0:0]
def get_model(args):
    """Build the denoiser model selected by the command-line ``args``.

    A local package path in ``args.model_path`` takes precedence. Otherwise the
    first truthy flag among ``args.dns64``, ``args.master64`` and
    ``args.valentini_nc`` (checked in that order) selects the matching
    pre-trained factory; when none is set, ``dns48`` is used.

    Args:
        args: namespace-like object exposing ``model_path``, ``dns64``,
            ``master64`` and ``valentini_nc``. The three flags are only read
            when ``model_path`` is falsy, and each one only when the flags
            before it were falsy.

    Returns:
        The model produced by ``deserialize_model`` (local package) or by the
        selected pre-trained factory.
    """
    if args.model_path:
        logger.info("Loading model from %s", args.model_path)
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # packages. Recent torch releases may default to weights_only=True,
        # which could reject packages holding non-tensor objects -- confirm
        # against the installed torch version.
        pkg = torch.load(args.model_path, 'cpu')
        if 'model' not in pkg:
            # The file is itself a serialized model package.
            model = deserialize_model(pkg)
        else:
            # Checkpoint layout: the model package lives under 'model'; when a
            # 'best_state' entry exists it replaces the package's 'state'.
            model_pkg = pkg['model']
            if 'best_state' in pkg:
                model_pkg['state'] = pkg['best_state']
            model = deserialize_model(model_pkg)
    else:
        # Ordered (flag attribute, log message, factory) entries; the first
        # enabled flag wins. The generator keeps attribute reads lazy.
        pretrained = (
            ('dns64', "Loading pre-trained real time H=64 model trained on DNS.", dns64),
            ('master64',
             "Loading pre-trained real time H=64 model trained on DNS and Valentini.",
             master64),
            ('valentini_nc', "Loading pre-trained H=64 model trained on Valentini.", valentini_nc),
        )
        message, factory = next(
            ((text, build) for flag, text, build in pretrained if getattr(args, flag)),
            ("Loading pre-trained real time H=48 model trained on DNS.", dns48),
        )
        logger.info(message)
        model = factory()
    logger.debug(model)
    return model