def train()

in notebooks/classify_mxnet.py


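The body below relies on module-level imports and on helper functions (get_train_data, get_val_data, define_network, test) defined elsewhere in classify_mxnet.py; hedged sketches of those helpers follow the function. The names used here resolve to roughly these imports:

import time

import mxnet as mx
from mxnet import autograd, gluon, init
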
def train(current_host, channel_input_dirs, hyperparameters, hosts, num_gpus):
    # SageMaker passes num_cpus, num_gpus and other args we can use to tailor training to
    # the current container environment; here we use a single GPU context when one is
    # available and fall back to CPU otherwise.
    ctx = [mx.gpu()] if num_gpus > 0 else [mx.cpu()]

    # retrieve the hyperparameters we set in notebook (with some defaults)
    batch_size = hyperparameters.get('batch_size', 100)
    epochs = hyperparameters.get('epochs', 10)
    learning_rate = hyperparameters.get('learning_rate', 0.1)
    momentum = hyperparameters.get('momentum', 0.9)
    wd = hyperparameters.get('wd', 0.001)
    log_interval = hyperparameters.get('log_interval', 100)

    # load training and validation data
    # we use the gluon.data.vision.MNIST class for its built-in MNIST pre-processing logic,
    # but point it at the location where SageMaker placed the data files, so it doesn't download them again.
    training_dir = channel_input_dirs['training']
    valid_dir = channel_input_dirs['validation']
    train_data = get_train_data(training_dir, batch_size)
    val_data = get_val_data(valid_dir, batch_size)

    # define the network
    net = define_network()

    # Initialize only the freshly added output layer (the rest of the network keeps its
    # existing weights) and give its parameters a 10x learning-rate multiplier.
    net.output.initialize(init.Xavier(), ctx=ctx)
    net.output.collect_params().setattr('lr_mult', 10)
    # Move all parameters, old and new, onto the training context.
    net.collect_params().reset_ctx(ctx)
    
    # The Trainer applies the optimizer's updates to the parameters. Choose a kvstore to
    # match the setup: 'device'/'local' on a single host, and the synchronous
    # distributed stores when training spans multiple hosts.
    if len(hosts) == 1:
        kvstore = 'device' if num_gpus > 0 else 'local'
    else:
        kvstore = 'dist_device_sync' if num_gpus > 0 else 'dist_sync'

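    # Note: 'adam' ignores the 'momentum' hyperparameter retrieved above; it would
    # only take effect with an sgd-style optimizer.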
    trainer = gluon.Trainer(net.collect_params(), 'adam',
                            {'learning_rate': learning_rate, 'wd': wd},
                            kvstore=kvstore)
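    # Accuracy tracks training progress; SoftmaxCrossEntropyLoss applies softmax
    # internally, so the network should emit raw logits.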
    metric = mx.metric.Accuracy()
    loss = gluon.loss.SoftmaxCrossEntropyLoss()

    # Shard the training data for distributed training. As an alternative to splitting
    # in memory, the data could be pre-split in S3 and distributed with the
    # ShardedByS3Key S3 data distribution type.
    if len(hosts) > 1:
        # materialize the DataLoader so the batches can be sliced by index
        train_data = list(train_data)
        shard_size = len(train_data) // len(hosts)
        # each host takes a contiguous, equally sized slice of the batches
        start = shard_size * hosts.index(current_host)
        end = start + shard_size
        train_data = train_data[start:end]
    
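    # hybridize() compiles the imperative Gluon graph into a cached symbolic graph
    # for faster execution.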
    net.hybridize()
    
    for epoch in range(epochs):
        # reset the accuracy metric at the beginning of each epoch.
        metric.reset()
        btic = time.time()
        for i, (data, label) in enumerate(train_data):
            # Copy data and label to the training context.
            data = data.as_in_context(ctx[0])
            label = label.as_in_context(ctx[0])
            # Start recording computation graph with record() section.
            # Recorded graphs can then be differentiated with backward.
            with autograd.record():
                output = net(data)
                L = loss(output, label)
                L.backward()
            # take a gradient step with batch_size equal to data.shape[0]
            trainer.step(data.shape[0])
            # update the accuracy metric with this batch's predictions.
            metric.update([label], [output])

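            # log progress periodically; samples/s reflects only the most recent
            # batch, since btic is reset after every batch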
            if i % log_interval == 0 and i > 0:
                name, acc = metric.get()
                print('[Epoch %d Batch %d] Training: %s=%f, %f samples/s' %
                      (epoch, i, name, acc, batch_size / (time.time() - btic)))

            btic = time.time()

        name, acc = metric.get()
        print('[Epoch %d] Training: %s=%f' % (epoch, name, acc))

        name, val_acc = test(ctx, net, val_data)
        print('[Epoch %d] Validation: %s=%f' % (epoch, name, val_acc))

    return net
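
The helpers referenced above — get_train_data, get_val_data, define_network, and test — are defined elsewhere in classify_mxnet.py and are not part of this excerpt. Based on the comments and call sites, they plausibly look like the sketches below; treat every detail (network choice, transforms, loader options) as an assumption rather than the notebook's actual code.

def get_train_data(data_dir, batch_size):
    # sketch: point MNIST at the channel directory so the files SageMaker already
    # placed there are reused; the transform scales pixels to [0, 1]
    dataset = gluon.data.vision.MNIST(
        data_dir, train=True,
        transform=lambda data, label: (data.astype('float32') / 255, label))
    return gluon.data.DataLoader(dataset, batch_size, shuffle=True, last_batch='discard')


def get_val_data(data_dir, batch_size):
    dataset = gluon.data.vision.MNIST(
        data_dir, train=False,
        transform=lambda data, label: (data.astype('float32') / 255, label))
    return gluon.data.DataLoader(dataset, batch_size, shuffle=False)


def define_network():
    # sketch only: train() initializes just net.output and boosts its learning rate,
    # which implies a network with pre-trained features and a replaced output layer;
    # a model-zoo backbone is one hypothetical way to get that shape
    from mxnet.gluon.model_zoo import vision as models
    net = models.resnet18_v1(pretrained=True)
    with net.name_scope():
        net.output = gluon.nn.Dense(10)
    return net


def test(ctx, net, val_data):
    # evaluate accuracy on the validation set; metric.get() returns a
    # (name, value) pair, matching how the result is unpacked in the epoch loop
    metric = mx.metric.Accuracy()
    for data, label in val_data:
        data = data.as_in_context(ctx[0])
        label = label.as_in_context(ctx[0])
        metric.update([label], [net(data)])
    return metric.get()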
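
In the legacy SageMaker MXNet container, train() is invoked automatically with the current host name, the per-channel input directories, the hyperparameters from the estimator, the list of hosts, and the GPU count. A hypothetical local smoke test of the same interface (host name and hyperparameter values are illustrative; /opt/ml/input/data/<channel> is where the container places channel data):

net = train(current_host='algo-1',
            channel_input_dirs={'training': '/opt/ml/input/data/training',
                                'validation': '/opt/ml/input/data/validation'},
            hyperparameters={'epochs': 1, 'batch_size': 100},
            hosts=['algo-1'],
            num_gpus=0)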