in containers/Shoot/CNN/host.py [0:0]
def model_fn(model_dir):
    """Build the network described in ``model_dir`` and load its trained weights.

    Presumably the SageMaker-style ``model_fn`` serving hook (called with the
    directory holding the model artifacts) -- confirm against the container.

    Reads ``hyperparameters.json`` from ``model_dir`` to size the network via
    ``model(depth=..., width=...)`` (``depth`` defaults to 2, ``width`` to 3;
    both are coerced with ``int``, presumably because the stored values may be
    strings -- verify). It then loads ``model-0000.params`` onto the
    module-level ``ctx``. The load is first attempted with the network cast to
    float16. If that raises, the exception is printed, the network is re-cast
    to float32 and the load is retried. In every successful case the returned
    network is cast to float32.

    Args:
        model_dir: Directory containing ``hyperparameters.json`` and
            ``model-0000.params``.

    Returns:
        The network produced by ``model(...)`` with its parameters loaded,
        cast to float32.

    Raises:
        OSError: If ``hyperparameters.json`` cannot be opened.
        json.JSONDecodeError: If ``hyperparameters.json`` is not valid JSON.
        Exception: Whatever the float32 load raises when both load attempts
            fail (the float16 error is attached as ``__context__``).
    """
    import os  # local import: the file's top-level imports are not visible here

    print("loading from %s" % (model_dir))

    hyperparameters_path = os.path.join(model_dir, "hyperparameters.json")
    params_path = os.path.join(model_dir, "model-0000.params")

    # JSON is UTF-8 by spec; don't depend on the container's locale encoding.
    with open(hyperparameters_path, "r", encoding="utf-8") as fp:
        hyperparameters = json.load(fp)

    net = model(
        depth=int(hyperparameters.get("depth", 2)),
        width=int(hyperparameters.get("width", 3)),
    )

    try:
        print("trying to load float16")
        net.cast("float16")
        net.collect_params().load(params_path, ctx)
    except Exception as e:
        # Deliberate best-effort: the stored weights may not be float16, so
        # report why and retry with a float32 network.
        print(e)
        print("trying to load float32")
        net.cast("float32")
        net.collect_params().load(params_path, ctx)

    # Return a float32 network regardless of the precision the weights used.
    net.cast("float32")
    return net