def train_module()

in examples/mxnet_imagenet_resnet50.py


def train_module():
    # Create input symbol
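    # For float16 training the input and the network are cast to half precision;
    # the output is cast back to float32 below so the loss is computed in full
    # precision (mixed-precision training).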
    data = mx.sym.var('data')
    if args.dtype == 'float16':
        data = mx.sym.Cast(data=data, dtype=np.float16)
        net.cast(np.float16)

    # Create output symbol
    out = net(data)
    if args.dtype == 'float16':
        out = mx.sym.Cast(data=out, dtype=np.float32)
    softmax = mx.sym.SoftmaxOutput(out, name='softmax')

    # Create model
    mod = mx.mod.Module(softmax, context=context)

    # Initialize parameters
    if args.use_pretrained:
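        # Copy the pretrained Gluon parameters to CPU NDArrays so they can be
        # handed to the Module as arg_params.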
        arg_params = {}
        for x in net.collect_params().values():
            x.reset_ctx(mx.cpu())
            arg_params[x.name] = x.data()
    else:
        arg_params = None
    aux_params = None
    mod.bind(data_shapes=train_data.provide_data,
             label_shapes=train_data.provide_label)
    mod.init_params(initializer, arg_params=arg_params, aux_params=aux_params)

    # Horovod: fetch and broadcast parameters
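    # Broadcasting from root_rank 0 ensures every worker starts from identical
    # initial (or pretrained) weights.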
    (arg_params, aux_params) = mod.get_params()
    if arg_params is not None:
        hvd.broadcast_parameters(arg_params, root_rank=0)
    if aux_params is not None:
        hvd.broadcast_parameters(aux_params, root_rank=0)
    mod.set_params(arg_params=arg_params, aux_params=aux_params)

    # Create optimizer
    # Note that when using the Module API, we need to specify rescale_grad since
    # we create the optimizer first and then wrap it with DistributedOptimizer.
    # With the Gluon API this is handled in Trainer.step(), so there is no need
    # to specify rescale_grad (see the train_gluon() function above).
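    # For reference, a rough sketch of the equivalent Gluon pattern (names here
    # are illustrative, not copied verbatim from train_gluon()):
    #   trainer = hvd.DistributedTrainer(net.collect_params(), opt)
    #   trainer.step(batch_size)  # Trainer.step() applies the 1/batch_size rescale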
    optimizer_params = {'wd': args.wd,
                        'momentum': args.momentum,
                        'rescale_grad': 1.0 / batch_size,
                        'lr_scheduler': lr_sched}
    if args.dtype == 'float16':
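        # multi_precision keeps a float32 master copy of the weights so the
        # float16 update does not lose precision.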
        optimizer_params['multi_precision'] = True
    opt = mx.optimizer.create('sgd', **optimizer_params)

    # Horovod: wrap optimizer with DistributedOptimizer
    dist_opt = hvd.DistributedOptimizer(opt,
                                        gradient_predivide_factor=args.gradient_predivide_factor)
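    # gradient_predivide_factor splits the averaging: gradients are divided by
    # the factor before allreduce and by (num_workers / factor) afterwards.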

    # Setup validation data and callback during training
    eval_data = None
    if args.eval_epoch:
        eval_data = val_data
    batch_callback = None
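    # Report throughput for the effective global batch (per-worker batch_size
    # times num_workers); only rank 0 logs it.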
    if args.log_interval > 0 and rank == 0:
        batch_callback = mx.callback.Speedometer(batch_size * num_workers,
                                                 args.log_interval)

    epoch_callback = None
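    # Each rank checkpoints under its own prefix (model name plus rank) to
    # avoid clobbering files written by other workers.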
    if args.save_frequency > 0:
        epoch_callback = mx.callback.do_checkpoint(
            '%s-%d' % (args.model, rank),
            period=args.save_frequency)

    # Train model
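    # kvstore=None: gradient aggregation is done by Horovod's allreduce inside
    # DistributedOptimizer, so MXNet's KVStore is not used.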
    mod.fit(train_data,
            eval_data=eval_data,
            num_epoch=args.num_epochs,
            kvstore=None,
            batch_end_callback=batch_callback,
            epoch_end_callback=epoch_callback,
            optimizer=dist_opt)

    # Evaluate performance if not using synthetic data
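    # Each rank scores its own shard of the validation set, so the logged
    # accuracy is per worker.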
    if args.use_rec:
        acc_top1 = mx.metric.Accuracy()
        acc_top5 = mx.metric.TopKAccuracy(5)
        res = mod.score(val_data, [acc_top1, acc_top5])
        for name, val in res:
            logging.info('Epoch[%d] Rank[%d] Validation-%s=%f',
                         args.num_epochs - 1, rank, name, val)
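
This function relies on module-level state created earlier in the script: the parsed args, the Horovod process info (rank, num_workers), the per-worker context, the network, the data iterators, the initializer, the learning-rate scheduler, and the batch size. The sketch below shows roughly how that state is typically set up for a Horovod MXNet job; the concrete values, the model-zoo call, and the synthetic data helper are illustrative assumptions, not a verbatim copy of the example script.

# Hypothetical sketch of the globals train_module() depends on.
import mxnet as mx
import numpy as np
import horovod.mxnet as hvd
from types import SimpleNamespace

# Illustrative stand-in for the script's argparse results.
args = SimpleNamespace(model='resnet50_v1', dtype='float32', use_pretrained=False,
                       no_cuda=True, batch_size=32, wd=1e-4, momentum=0.9,
                       num_epochs=1, gradient_predivide_factor=1.0,
                       eval_epoch=False, log_interval=10, save_frequency=0,
                       use_rec=False)

hvd.init()                                   # one Horovod process per worker
rank = hvd.rank()
num_workers = hvd.size()
context = mx.cpu() if args.no_cuda else mx.gpu(hvd.local_rank())

batch_size = args.batch_size
net = mx.gluon.model_zoo.vision.get_model(args.model, pretrained=args.use_pretrained)
initializer = mx.init.Xavier(rnd_type='gaussian', factor_type='in', magnitude=2)
lr_sched = mx.lr_scheduler.PolyScheduler(max_update=1000, base_lr=0.1)

def _synthetic_iter(num_batches):
    # Hypothetical helper: random images/labels standing in for the sharded
    # ImageNet recordio iterators used by the real script.
    x = np.random.uniform(size=(batch_size * num_batches, 3, 224, 224)).astype('float32')
    y = np.random.randint(0, 1000, (batch_size * num_batches,)).astype('float32')
    return mx.io.NDArrayIter(x, y, batch_size, label_name='softmax_label')

train_data = _synthetic_iter(4)
val_data = _synthetic_iter(2)

In the real example these objects are built from the command-line arguments, and the script is typically launched with horovodrun, e.g. horovodrun -np 4 python examples/mxnet_imagenet_resnet50.py.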