def model_fn()

in containers/Shoot/CNN/host.py [0:0]


def model_fn(model_dir):
    """Construct the network stored under *model_dir* and load its weights.

    Reads ``<model_dir>/hyperparameters.json`` for the ``depth`` (default 2)
    and ``width`` (default 3) settings and passes them, converted with
    ``int``, to the module-level ``model`` constructor.

    Loads ``<model_dir>/model-0000.params`` into the network using the
    module-level ``ctx``. The load is first attempted with the network cast to
    float16. If that raises, the error is printed and the load is retried with
    the network cast to float32.

    NOTE(review): the name suggests this is a SageMaker inference ``model_fn``
    hook; confirm against the hosting container.

    Args:
        model_dir: Directory holding ``hyperparameters.json`` and
            ``model-0000.params``.

    Returns:
        The loaded network, cast to float32.

    Raises:
        Any exception from opening/parsing the hyperparameters file, or from
        the float32 fallback load (the float16 failure is only printed).
    """
    print("loading from %s" % (model_dir))
    hyper_path = "%s/hyperparameters.json" % (model_dir)
    params_path = "%s/model-0000.params" % (model_dir)

    with open(hyper_path, "r") as handle:
        hp = json.load(handle)

    network = model(
        depth=int(hp.get("depth", 2)),
        width=int(hp.get("width", 3)),
    )

    def _load_as(dtype):
        # Cast first, then load. Presumably this makes the parameter dtype
        # match the checkpoint, so a mismatch raises and triggers the
        # fallback below. TODO confirm.
        network.cast(dtype)
        network.collect_params().load(params_path, ctx)

    try:
        print("trying to load float16")
        _load_as("float16")
    except Exception as err:
        print(err)
        print("trying to load float32")
        _load_as("float32")

    # Regardless of which precision the weights were loaded in, hand back a
    # float32 network.
    network.cast("float32")
    return network