in example/model-parallel-lstm/lstm.py [0:0]
def train_lstm(model, X_train_batch, X_val_batch,
               num_round, update_period, concat_decode, batch_size, use_loss,
               optimizer='sgd', half_life=2, max_grad_norm=5.0, **kwargs):
    opt = mx.optimizer.create(optimizer, **kwargs)
    updater = mx.optimizer.get_updater(opt)
    epoch_counter = 0
    #log_period = max(1000 / seq_len, 1)
    log_period = 28
    last_perp = 10000000.0
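    # each round is one pass over the training data followed by a full
    # validation pass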
    for iteration in range(num_round):
        nbatch = 0
        train_nll = 0
        tic = time.time()
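        # each batch carries a bucket_key (the unrolled sequence length),
        # which selects the matching unrolled model from `model`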
        for data_batch in X_train_batch:
            batch_seq_length = data_batch.bucket_key
            m = model[batch_seq_length]
            # reset init state
            for state in m.init_states:
                state.c[:] = 0.0
                state.h[:] = 0.0
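            # when training against explicit loss outputs, backward() needs
            # head gradients: an all-ones NDArray for each scalar loss output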
            head_grad = []
            if use_loss:
                ctx = m.seq_outputs[0].context
                head_grad = [mx.nd.ones((1,), ctx) for x in m.seq_outputs]
            set_rnn_inputs_from_batch(m, data_batch, batch_seq_length, batch_size)
            m.rnn_exec.forward(is_train=True)
            # probability of each label class, used to evaluate nll
            # Change back to individual ops to see if fine grained scheduling helps.
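            # with concat_decode the per-step outputs are concatenated into a
            # single NDArray, so the label probabilities can be picked in one op;
            # otherwise they are picked per time step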
            if not use_loss:
                if concat_decode:
                    seq_label_probs = mx.nd.choose_element_0index(m.seq_outputs[0], m.seq_labels[0])
                else:
                    seq_label_probs = [mx.nd.choose_element_0index(out, label).copyto(mx.cpu())
                                       for out, label in zip(m.seq_outputs, m.seq_labels)]
                m.rnn_exec.backward()
            else:
                seq_loss = [x.copyto(mx.cpu()) for x in m.seq_outputs]
                m.rnn_exec.backward(head_grad)
            # update epoch counter
            epoch_counter += 1
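            # apply the optimizer only every `update_period` mini-batches;
            # gradients are cleared right after each update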
            if epoch_counter % update_period == 0:
                # update parameters: normalize gradients by the batch size and
                # clip by the global L2 norm before each optimizer step
                norm = 0.
                for idx, weight, grad, name in m.param_blocks:
                    grad /= batch_size
                    l2_norm = mx.nd.norm(grad).asscalar()
                    norm += l2_norm * l2_norm
                norm = math.sqrt(norm)
                for idx, weight, grad, name in m.param_blocks:
                    if norm > max_grad_norm:
                        grad *= (max_grad_norm / norm)
                    updater(idx, grad, weight)
                    # reset gradient to zero
                    grad[:] = 0.0
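            # accumulate training negative log-likelihood, either from the
            # picked label probabilities or directly from the loss outputs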
            if not use_loss:
                if concat_decode:
                    train_nll += calc_nll_concat(seq_label_probs, batch_size)
                else:
                    train_nll += calc_nll(seq_label_probs, batch_size, batch_seq_length)
            else:
                train_nll += sum([x.asscalar() for x in seq_loss]) / batch_size
            nbatch += batch_size
            toc = time.time()
            if epoch_counter % log_period == 0:
                print("Iter [%d] Train: Time: %.3f sec, NLL=%.3f, Perp=%.3f" % (
                    epoch_counter, toc - tic, train_nll / nbatch, np.exp(train_nll / nbatch)))
        # end of the pass over the training data for this round
        toc = time.time()
        print("Iter [%d] Train: Time: %.3f sec, NLL=%.3f, Perp=%.3f" % (
            iteration, toc - tic, train_nll / nbatch, np.exp(train_nll / nbatch)))
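        # validation pass: forward only, no parameter updates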
        val_nll = 0.0
        nbatch = 0
        for data_batch in X_val_batch:
            batch_seq_length = data_batch.bucket_key
            m = model[batch_seq_length]
            # validation set, reset states
            for state in m.init_states:
                state.h[:] = 0.0
                state.c[:] = 0.0
            set_rnn_inputs_from_batch(m, data_batch, batch_seq_length, batch_size)
            m.rnn_exec.forward(is_train=False)
            # probability of each label class, used to evaluate nll
            if not use_loss:
                if concat_decode:
                    seq_label_probs = mx.nd.choose_element_0index(m.seq_outputs[0], m.seq_labels[0])
                else:
                    seq_label_probs = [mx.nd.choose_element_0index(out, label).copyto(mx.cpu())
                                       for out, label in zip(m.seq_outputs, m.seq_labels)]
            else:
                seq_loss = [x.copyto(mx.cpu()) for x in m.seq_outputs]
            if not use_loss:
                if concat_decode:
                    val_nll += calc_nll_concat(seq_label_probs, batch_size)
                else:
                    val_nll += calc_nll(seq_label_probs, batch_size, batch_seq_length)
            else:
                val_nll += sum([x.asscalar() for x in seq_loss]) / batch_size
            nbatch += batch_size
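        # perplexity is the exponential of the average per-sample NLL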
        perp = np.exp(val_nll / nbatch)
        print("Iter [%d] Val: NLL=%.3f, Perp=%.3f" % (
            iteration, val_nll / nbatch, perp))
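        # anneal the learning rate: if validation perplexity failed to improve
        # by at least 1.0 over the previous round, halve it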
        if last_perp - 1.0 < perp:
            opt.lr *= 0.5
            print("Reduced learning rate to %g" % opt.lr)
        last_perp = perp
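        # rewind both data iterators before the next round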
        X_val_batch.reset()
        X_train_batch.reset()