in word_language_model/word_language_model.py
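
# The training loop below relies on helpers defined elsewhere in this file.
# What follows are minimal sketches reconstructed from how the call sites use
# them, not the file's exact code; `args`, `model`, `loss`, and `trainer` are
# module-level globals assumed to exist. The imports are the ones this
# excerpt needs.
import logging
import math
import time

import mxnet as mx
from mxnet import autograd, gluon


def get_batch(source, i):
    # Slice a (seq_len, batch_size) chunk of inputs plus the one-step-shifted
    # targets; target stays 2-D so it can be split along the batch axis below.
    seq_len = min(args.bptt, source.shape[0] - 1 - i)
    data = source[i:i + seq_len]
    target = source[i + 1:i + 1 + seq_len]
    return data, target


def detach(hidden):
    # Detach the hidden state(s) from the previous graph so gradients stop at
    # the BPTT boundary instead of flowing back through the whole epoch.
    if isinstance(hidden, (tuple, list)):
        hidden = [detach(h) for h in hidden]
    else:
        hidden = hidden.detach()
    return hidden
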
def train(epochs, ctx):
    best_val = float("Inf")
    for epoch in range(epochs):
        total_L = 0.0
        start_time = time.time()
        # one hidden state per device; each device carries
        # batch_size // len(ctx) sequences of the batch
        hidden_states = [
            model.begin_state(func=mx.nd.zeros,
                              batch_size=args.batch_size // len(ctx),
                              ctx=ctx[i])
            for i in range(len(ctx))
        ]
        for ibatch, i in enumerate(range(0, train_data.shape[0] - 1, args.bptt)):
            # get a (bptt, batch_size) batch from the training data
            data_batch, target_batch = get_batch(train_data, i)
            # for RNNs we can do within-batch multi-device parallelism:
            # split along the batch axis and run one shard per device
            data = gluon.utils.split_and_load(data_batch, ctx_list=ctx, batch_axis=1)
            target = gluon.utils.split_and_load(target_batch, ctx_list=ctx, batch_axis=1)
            Ls = []
            for (d, t) in zip(data, target):
                # fetch this device's hidden state, detached from the previous graph
                hidden = detach(hidden_states[d.context.device_id])
                with autograd.record():
                    output, hidden = model(d, hidden)
                    L = loss(output, t.reshape((-1,)))
                L.backward()
                Ls.append(L)
                # write the updated hidden state back for the next batch
                hidden_states[d.context.device_id] = hidden
            for c in ctx:
                # the gradient here is for the whole batch, so we scale
                # max_norm by bptt * batch_size to balance it; clip_global_norm
                # must be applied to arrays living on the same context
                grads = [p.grad(c) for p in model.collect_params().values()]
                gluon.utils.clip_global_norm(
                    grads, args.clip * args.bptt * args.batch_size / len(ctx))
            trainer.step(args.batch_size)
            for L in Ls:
                total_L += mx.nd.sum(L).asscalar()

            if ibatch % args.log_interval == 0 and ibatch > 0:
                cur_L = total_L / args.bptt / args.batch_size / args.log_interval
                logging.info('[Epoch %d Batch %d] loss %.2f, ppl %.2f' % (
                    epoch, ibatch, cur_L, math.exp(cur_L)))
                total_L = 0.0
        val_L = eval(val_data, ctx)
        logging.info('[Epoch %d] time cost %.2fs, valid loss %.2f, valid ppl %.2f' % (
            epoch, time.time() - start_time, val_L, math.exp(val_L)))

        if val_L < best_val:
            # new best validation loss: evaluate on test and checkpoint
            best_val = val_L
            test_L = eval(test_data, ctx)
            model.collect_params().save('model.params')
            logging.info('test loss %.2f, test ppl %.2f' % (test_L, math.exp(test_L)))
        else:
            # no improvement: anneal the learning rate and roll back to the
            # best checkpoint (note: _init_optimizer is a private Trainer API)
            args.lr = args.lr * 0.25
            trainer._init_optimizer('sgd',
                                    {'learning_rate': args.lr,
                                     'momentum': 0,
                                     'wd': 0})
            model.collect_params().load('model.params', ctx)
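
# `eval` (the name shadows Python's built-in) is likewise defined elsewhere in
# the file. A minimal sketch, assuming val_data/test_data are batchified to
# args.batch_size and that evaluation runs on the first device only:
def eval(data_source, ctx):
    total_L = 0.0
    ntotal = 0
    hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size,
                               ctx=ctx[0])
    for i in range(0, data_source.shape[0] - 1, args.bptt):
        data, target = get_batch(data_source, i)
        data = data.as_in_context(ctx[0])
        target = target.as_in_context(ctx[0])
        output, hidden = model(data, hidden)
        L = loss(output, target.reshape((-1,)))
        total_L += mx.nd.sum(L).asscalar()
        ntotal += L.size
    return total_L / ntotal

# Typical invocation (hypothetical; the real driver code lives elsewhere):
# context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
# train(args.epochs, context)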