def train()

in word_language_model/word_language_model_train.py


def train(epochs, ctx):
    best_val = float("Inf")

    for epoch in range(epochs):
        total_L = 0.0
        cur_L = 0.0
        tic = time.time()
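        # keep one hidden state per device; it persists across batches within the
        # epoch and is detached each step (truncated backpropagation through time)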
        hidden_states = [
            model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size // len(ctx), ctx=ctx[i])
            for i in range(len(ctx))
        ]
        btic = time.time()
        for ibatch, i in enumerate(range(0, train_data.shape[0] - 1, args.bptt)):
            # get data batch from the training data
            data_batch, target_batch = get_batch(train_data, i)
            # For RNN we can do within batch multi-device parallelization
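            # data_batch has shape (bptt, batch_size), so we split along axis 1 to
            # give each device a slice of the batch with the full time dimension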
            data = gluon.utils.split_and_load(data_batch, ctx_list=ctx, batch_axis=1)
            target = gluon.utils.split_and_load(target_batch, ctx_list=ctx, batch_axis=1)
            Ls = []
            for (d, t) in zip(data, target):
                # fetch this device's hidden state, detached from the previous batch's graph
                hidden = detach(hidden_states[d.context.device_id])
                with autograd.record():
                    output, hidden = model(d, hidden)
                    L = loss(output, t.reshape((-1,)))
                    L.backward()
                    Ls.append(L)
                # store the updated hidden state so the next batch continues from it
                hidden_states[d.context.device_id] = hidden

            for c in ctx:
                grads = [p.grad(c) for p in model.collect_params().values()]
                # Gradients are accumulated over the whole per-device batch, so
                # max_norm is scaled by bptt * batch_size / len(ctx) to compensate.
                # clip_global_norm must be applied to arrays on the same context.
                gluon.utils.clip_global_norm(grads, args.clip * args.bptt * args.batch_size / len(ctx))

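            # step(batch_size) rescales the accumulated gradients by 1/batch_size
            # before applying the update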
            trainer.step(args.batch_size)
            for L in Ls:
                total_L += mx.nd.sum(L).asscalar()

            if ibatch % args.log_interval == 0 and ibatch > 0:
                cur_L = total_L / args.bptt / args.batch_size / args.log_interval
                logging.info('[Epoch %d Batch %d] Speed: %f samples/sec, loss %.2f, ppl %.2f' % (
                    epoch, ibatch, args.batch_size/(time.time()-btic), cur_L, math.exp(cur_L)))
                total_L = 0.0
            btic = time.time()

        logging.info('[Epoch %d] train loss %.2f, train ppl %.2f' % (epoch, cur_L, math.exp(cur_L)))
        logging.info('[Epoch %d] time cost %.2fs' % (epoch, time.time()-tic))
        val_L = eval(val_data, ctx)
        logging.info('[Epoch %d] valid loss %.2f, valid ppl %.2f' % (epoch, val_L, math.exp(val_L)))

        if val_L < best_val:
            best_val = val_L
            # test_L = eval(test_data, ctx)
            model.collect_params().save('model.params')
            # logging.info('test loss %.2f, test ppl %.2f' % (test_L, math.exp(test_L)))
        else:
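            # no improvement: anneal the learning rate by 4x and roll back to the
            # previously saved best parameters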
            args.lr = args.lr * 0.25
            trainer._init_optimizer('sgd',
                {
                    'learning_rate': args.lr,
                    'momentum': 0,
                    'wd': 0
                }
            )
            model.collect_params().load('model.params', ctx)
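
The loop above relies on two helpers defined elsewhere in word_language_model_train.py, get_batch and detach. The exact definitions live in that file; a minimal sketch of the conventional versions, assuming train_data has been batchified into a (num_steps, batch_size) NDArray, looks like this:

def get_batch(source, i):
    # slice a window of at most bptt inputs and the corresponding next-token targets
    seq_len = min(args.bptt, source.shape[0] - 1 - i)
    data = source[i:i + seq_len]
    target = source[i + 1:i + 1 + seq_len]
    return data, target

def detach(hidden):
    # cut the autograd graph so gradients do not flow into earlier batches
    if isinstance(hidden, (tuple, list)):
        return [detach(h) for h in hidden]
    return hidden.detach()

Detaching the hidden state each step is what keeps backpropagation truncated to a single bptt window, so memory use stays bounded no matter how long the epoch runs.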