def train()

in containers/Shoot/CNN/train.py [0:0]
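
Trains the CNN returned by model(), either from scratch or by fine-tuning a
prebuilt checkpoint supplied through the SageMaker "model" channel. Training
runs in float16 on GPU (float32 on CPU), optimizes a masked Dice objective
with Adam, and early-stops on validation loss after a warm-up period.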


import json
import os
import shutil
import time

import mxnet as mx
from mxnet import autograd, gluon

# model, DiceLoss, get_data and save are used below without import; they are
# assumed to be defined elsewhere in containers/Shoot/CNN.


def train(hyperparameters, hosts, num_gpus, **kwargs):
    # Probe for a GPU: allocating on gpu(0) raises MXNetError on a CPU-only
    # host, in which case training falls back to float32 on the CPU.
    try:
        _ = mx.nd.array([1], ctx=mx.gpu(0))
        ctx = [mx.gpu(i) for i in range(num_gpus)]
        print("using GPU")
        DTYPE = "float16"
        host_ctx = mx.cpu_pinned(0)
    except mx.MXNetError:
        ctx = [mx.cpu()]
        print("using CPU")
        DTYPE = "float32"
        host_ctx = mx.cpu(0)

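    # SageMaker exposes an input channel named "model" through SM_CHANNEL_MODEL;
    # when it is set, fine-tune that checkpoint instead of training from scratch.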
    model_dir = os.environ.get("SM_CHANNEL_MODEL")
    if model_dir:
        print("using prebuild model")
        shutil.unpack_archive("%s/model.tar.gz" % (model_dir), model_dir)
        with open('%s/hyperparameters.json' % (model_dir), 'r') as fp:
            saved_hyperparameters = json.load(fp)

        net = model(
            depth=int(saved_hyperparameters.get("depth", 2)),
            width=int(saved_hyperparameters.get("width", 3)),
        )
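        # The checkpoint's dtype is unknown (float16 if it was trained on GPU,
        # float32 if on CPU), so try float16 first and fall back to float32.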
        try:
            print("trying to load float16")
            net.cast("float16")
            net.collect_params().load("%s/model-0000.params" % (model_dir), ctx)
        except Exception as e:
            print(e)
            print("trying to load float32")
            net.cast("float32")
            net.collect_params().load("%s/model-0000.params" % (model_dir), ctx)
        net.cast(DTYPE)
    else:
        print("building model from scratch")
        net = model(
            depth=int(hyperparameters.get("depth", 2)),
            width=int(hyperparameters.get("width", 3)),
        )
        net.cast(DTYPE)
    # initialize() without force_reinit leaves any loaded weights untouched
    net.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
    net.hybridize()
    print(net)

    dice = DiceLoss()
    dice.cast(DTYPE)
    dice.hybridize()

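    # When fine-tuning a prebuilt model, only a subset of layers is trained;
    # collect_params_layers(2) looks like a project-specific helper on the
    # network that selects that subset (it is not part of the stock Gluon API).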
    trainer = gluon.Trainer(
        net.collect_params_layers(2) if model_dir else net.collect_params(),
        'adam',
        {
            "multi_precision": (DTYPE == 'float16'),
            'learning_rate': float(hyperparameters.get("learning_rate", .001))
        })
    train_iter, test_iter = get_data(int(hyperparameters.get("batch_size", 8)), DTYPE, host_ctx)

    Loss = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False)  # defined but never used below

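    # Early-stopping state: checkpoints are only considered after warm_up
    # epochs, and training stops once more than patience epochs pass without
    # an improvement in validation loss.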
    best = float("inf")
    warm_up = int(hyperparameters.get("warm_up", 30))
    patience = int(hyperparameters.get("patience", 10))
    wait = 0

    for e in range(int(hyperparameters.get("epochs", 11))):
        print("Epoch %s" % (e))
        val_loss = 0
        st = time.time()
        training_count = 0
        testing_count = 0
        training_loss = 0

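        # Training pass: shard each batch across the available devices.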
        for batch in train_iter:
            batch_size = batch.data[0].shape[0]
            training_count += batch_size
            data = gluon.utils.split_and_load(batch.data[0].astype(DTYPE), ctx)
            label = gluon.utils.split_and_load(batch.label[0].astype(DTYPE).reshape((batch_size, -1)), ctx)
            mask = gluon.utils.split_and_load(batch.label[1].astype(DTYPE).reshape((batch_size, -1)), ctx)

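            # dice() is negated so that minimizing the loss maximizes the Dice
            # score; the epoch prints below undo the negation.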
            with autograd.record():
                output = [net(x) for x in data]
                losses = [-dice(x, y, z) for x, y, z in zip(output, label, mask)]
            for loss in losses:
                loss.backward()
            trainer.step(batch_size)
            training_loss += sum(loss.sum().asscalar() for loss in losses)

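        # Validation pass: same sharding, but outside autograd.record(), so
        # no gradients are computed.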
        for batch in test_iter:
            batch_size = batch.data[0].shape[0]
            testing_count += batch_size

            data = gluon.utils.split_and_load(batch.data[0].astype(DTYPE), ctx)
            label = gluon.utils.split_and_load(batch.label[0].astype(DTYPE).reshape((batch_size, -1)), ctx)
            mask = gluon.utils.split_and_load(batch.label[1].astype(DTYPE).reshape((batch_size, -1)), ctx)

            output = [net(x) for x in data]
            losses = [-dice(x, y, z) for x, y, z in zip(output, label, mask)]

            val_loss += sum(loss.sum().asscalar() for loss in losses)

        et = time.time()
        print("Hyperparameters: %s;" % (hyperparameters))
        print("Training loss: %s;" % (-training_loss / training_count))
        print("Testing loss: %s;" % (-val_loss / (testing_count)))
        print("Throughput=%2.2f;" % ((training_count + testing_count) / (et - st)))
        print("Validation Loss=%2.2f;" % val_loss)
        print("Best=%2.2f;" % best)

        if e >= warm_up:
            if val_loss < best:
                print("best model: %s;" % (-val_loss / (testing_count)))
                save(net, hyperparameters)
                best = val_loss
                wait = 0
            else:
                wait += 1
        if wait > patience:
            print("stoping early")
            break
        train_iter.reset()
        test_iter.reset()
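
Two of the helpers referenced above live elsewhere in the container, so the
sketches below are included only for orientation; they are assumptions
reconstructed from the call sites, not the project's actual code.

A minimal DiceLoss sketch, assuming dice(pred, label, mask) returns a
per-sample soft Dice score over masked pixels (the training loop negates it)
and that the network output is flattened to match the (batch, -1) labels:

from mxnet import gluon


class DiceLoss(gluon.HybridBlock):
    def hybrid_forward(self, F, pred, label, mask, eps=1e-6):
        # score only the pixels the mask marks as valid
        pred = F.sigmoid(pred) * mask
        label = label * mask
        inter = F.sum(pred * label, axis=1)                 # per-sample overlap
        union = F.sum(pred, axis=1) + F.sum(label, axis=1)
        return (2.0 * inter + eps) / (union + eps)          # higher is better

A save() sketch consistent with how train() unpacks checkpoints above
(model-0000.params plus hyperparameters.json inside model.tar.gz). The
/opt/ml/model default is SageMaker's SM_MODEL_DIR convention, and SageMaker
archives that directory into model.tar.gz after the job, which matches the
unpack path used when the model channel is supplied:

import json
import os


def save(net, hyperparameters):
    model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
    net.collect_params().save("%s/model-0000.params" % (model_dir))
    with open("%s/hyperparameters.json" % (model_dir), "w") as fp:
        json.dump(hyperparameters, fp)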