in sagemaker/src/handwriting_line_recognition.py [0:0]
def run_epoch(e, network, dataloader, trainer, log_dir, print_name, is_train):
    '''
    Run one epoch to train or evaluate the CNN-biLSTM network.

    Relies on module-level names defined elsewhere in this file: ``ctx`` (the
    list of contexts each batch is split across), ``ctc_loss``,
    ``send_image_every_n``, ``decode`` and ``draw_text_on_image``.

    Parameters
    ----------
    e: int
        The epoch number. On the first batch of every ``send_image_every_n``-th
        epoch (excluding epoch 0), the first sample's prediction is decoded
        and printed.
    network: nn.Gluon.HybridSequential
        The CNN-biLSTM network
    dataloader: gluon.data.DataLoader
        The train or testing dataloader that is wrapped around the iam_dataset
    trainer: gluon.Trainer
        Optimizer that updates ``network``; only used when ``is_train`` is True.
    log_dir: Str
        Directory intended for mxboard log files. NOTE(review): it is not
        used anywhere in this function body.
    print_name: Str
        Name to print for associating with the data. Usually this will be
        "train" or "test".
    is_train: bool
        Whether the network should be updated. Should only be True for the
        training data.

    Returns
    -------
    epoch_loss: float
        The CTC loss of the current epoch: the per-device mean loss, averaged
        over all batches and devices.

    Raises
    ------
    ValueError
        If ``dataloader`` yields no batches.
    '''
    total_loss = [nd.zeros(1, ctx_) for ctx_ in ctx]
    num_batches = 0
    for batch_idx, (x_, y_) in enumerate(dataloader):
        num_batches += 1
        X = gluon.utils.split_and_load(x_, ctx)
        Y = gluon.utils.split_and_load(y_, ctx)

        # Only record the autograd graph when backward() will be called.
        # For evaluation, predict_mode() runs the forward pass in inference
        # mode without paying the memory cost of recording.
        if is_train:
            forward_scope = autograd.record(train_mode=True)
        else:
            forward_scope = autograd.predict_mode()
        with forward_scope:
            output = [network(x) for x in X]
            loss_ctc = [ctc_loss(o, y) for o, y in zip(output, Y)]

        if is_train:
            for l in loss_ctc:
                l.backward()
            # Normalise the gradient update by the full (un-split) batch size.
            trainer.step(x_.shape[0])

        if batch_idx == 0 and e % send_image_every_n == 0 and e > 0:
            predictions = output[0][:1].softmax().topk(axis=2).asnumpy()
            decoded_text = decode(predictions)
            image = X[0][:1].asnumpy()
            # Undo the input normalisation before drawing. Presumably these
            # are the dataset's pixel std and mean -- TODO confirm against
            # the data loader's transform.
            image = image * 0.15926149044640417 + 0.942532484060557
            # NOTE(review): the rendered image is never used or logged here
            # (log_dir is unused); the call is kept in case
            # draw_text_on_image has side effects.
            output_image = draw_text_on_image(image, decoded_text)
            print("{} first decoded text = {}".format(print_name, decoded_text[0]))

        for device_idx, l in enumerate(loss_ctc):
            total_loss[device_idx] += l.mean()

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches; cannot compute epoch loss")
    epoch_loss = float(sum(tl.asscalar() for tl in total_loss)) / (num_batches * len(ctx))
    return epoch_loss