def train_lstm()

in example/model-parallel-lstm/lstm.py [0:0]
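This excerpt assumes the module-level imports used by the function; the helpers set_rnn_inputs_from_batch, calc_nll_concat, and calc_nll are defined elsewhere in lstm.py.

import math
import time

import numpy as np
import mxnet as mx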


def train_lstm(model, X_train_batch, X_val_batch,
               num_round, update_period, concat_decode, batch_size, use_loss,
               optimizer='sgd', half_life=2, max_grad_norm=5.0, **kwargs):
    opt = mx.optimizer.create(optimizer, **kwargs)

    updater = mx.optimizer.get_updater(opt)
    epoch_counter = 0
    #log_period = max(1000 / seq_len, 1)
    log_period = 28
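    # previous round's validation perplexity; initialized high so the first
    # round never triggers a learning-rate cut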
    last_perp = 10000000.0

    for iteration in range(num_round):
        nbatch = 0
        train_nll = 0
        tic = time.time()
        for data_batch in X_train_batch:  
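            # each batch carries a bucket key (its sequence length); pick the
            # executor that was built for that bucket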
            batch_seq_length = data_batch.bucket_key
            m = model[batch_seq_length]
            # reset init state
            for state in m.init_states:
                state.c[:] = 0.0
                state.h[:] = 0.0
              
            # when the network outputs per-step losses, backpropagation starts
            # from unit head gradients on each loss output
            head_grad = []
            if use_loss:
                ctx = m.seq_outputs[0].context
                head_grad = [mx.nd.ones((1,), ctx) for x in m.seq_outputs]

            set_rnn_inputs_from_batch(m, data_batch, batch_seq_length, batch_size)  

            m.rnn_exec.forward(is_train=True)
            # probability of each label class, used to evaluate nll
            # Change back to individual ops to see if fine grained scheduling helps.
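            # two decoding modes: with concat_decode the per-step outputs and
            # labels are concatenated into a single pair (index 0), otherwise
            # each time step keeps its own output; with use_loss the outputs
            # are already per-step loss values rather than class probabilities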
            if not use_loss:
                if concat_decode:
                    seq_label_probs = mx.nd.choose_element_0index(m.seq_outputs[0], m.seq_labels[0])
                else:
                    seq_label_probs = [mx.nd.choose_element_0index(out, label).copyto(mx.cpu())
                                       for out, label in zip(m.seq_outputs, m.seq_labels)]
                m.rnn_exec.backward()
            else:
                seq_loss = [x.copyto(mx.cpu()) for x in m.seq_outputs]
                m.rnn_exec.backward(head_grad)

            # update epoch counter
            epoch_counter += 1
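            # gradients accumulate across update_period batches; clip them to a
            # global L2 norm of max_grad_norm, apply the optimizer update, then
            # zero them for the next accumulation window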
            if epoch_counter % update_period == 0:
                # update parameters
                norm = 0.
                for idx, weight, grad, name in m.param_blocks:
                    grad /= batch_size
                    l2_norm = mx.nd.norm(grad).asscalar()
                    norm += l2_norm*l2_norm
                norm = math.sqrt(norm)
                for idx, weight, grad, name in m.param_blocks:
                    if norm > max_grad_norm:
                        grad *= (max_grad_norm / norm)
                    updater(idx, grad, weight)
                    # reset gradient to zero
                    grad[:] = 0.0
            if not use_loss:
                if concat_decode:
                    train_nll += calc_nll_concat(seq_label_probs, batch_size)
                else:
                    train_nll += calc_nll(seq_label_probs, batch_size, batch_seq_length)
            else:
                train_nll += sum([x.asscalar() for x in seq_loss]) / batch_size

            nbatch += batch_size
            toc = time.time()
            if epoch_counter % log_period == 0:
                print("Iter [%d] Train: Time: %.3f sec, NLL=%.3f, Perp=%.3f" % (
                    epoch_counter, toc - tic, train_nll / nbatch, np.exp(train_nll / nbatch)))
        # end of training loop
        toc = time.time()
        print("Iter [%d] Train: Time: %.3f sec, NLL=%.3f, Perp=%.3f" % (
            iteration, toc - tic, train_nll / nbatch, np.exp(train_nll / nbatch)))

        val_nll = 0.0
        nbatch = 0
        for data_batch in X_val_batch:
            batch_seq_length = data_batch.bucket_key
            m = model[batch_seq_length]

            # validation set, reset states
            for state in m.init_states:
                state.h[:] = 0.0
                state.c[:] = 0.0

            set_rnn_inputs_from_batch(m, data_batch, batch_seq_length, batch_size)
            m.rnn_exec.forward(is_train=False)

            # probability of each label class, used to evaluate nll
            if not use_loss:
                if concat_decode:
                    seq_label_probs = mx.nd.choose_element_0index(m.seq_outputs[0], m.seq_labels[0])
                else:
                    seq_label_probs = [mx.nd.choose_element_0index(out, label).copyto(mx.cpu())
                                       for out, label in zip(m.seq_outputs, m.seq_labels)]
            else:
                seq_loss = [x.copyto(mx.cpu()) for x in m.seq_outputs]

            if not use_loss:
                if concat_decode:
                    val_nll += calc_nll_concat(seq_label_probs, batch_size)
                else:
                    val_nll += calc_nll(seq_label_probs, batch_size, batch_seq_length)
            else:
                val_nll += sum([x.asscalar() for x in seq_loss]) / batch_size
            nbatch += batch_size
            
        perp = np.exp(val_nll / nbatch)
        print("Iter [%d] Val: NLL=%.3f, Perp=%.3f" % (
            iteration, val_nll / nbatch, perp))
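        # halve the learning rate when validation perplexity fails to improve
        # by at least 1.0 over the previous round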
        if last_perp - 1.0 < perp:
            opt.lr *= 0.5
            print("Halving learning rate to %g" % opt.lr)
        last_perp = perp
        X_val_batch.reset()
        X_train_batch.reset()
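
An illustrative call is sketched below. The per-bucket model dictionary and the bucketed iterators are assumed to come from the setup code elsewhere in lstm.py, and any extra keyword arguments (here learning_rate and wd) are forwarded to mx.optimizer.create; the concrete values are placeholders, not the example's defaults.

# Illustrative invocation (hypothetical values): `model` maps bucket keys
# (sequence lengths) to executors, and X_train_batch / X_val_batch are the
# bucketed data iterators built elsewhere in lstm.py.
train_lstm(model, X_train_batch, X_val_batch,
           num_round=25, update_period=1, concat_decode=False,
           batch_size=32, use_loss=True,
           optimizer='sgd', learning_rate=0.1, wd=0.00001)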