in sagemaker/src/word_and_line_segmentation.py [0:0]
def run_epoch(e, network, dataloader, trainer, print_name, is_train, update_metric):
    '''
    Run one epoch to train or test the SSD network

    Relies on module-level names defined elsewhere in this file: ctx,
    cls_loss, box_loss, cls_metric, box_metric, send_image_every_n and
    generate_output_image.

    Parameters
    ----------
    e: int
        The epoch number. Only used to decide whether the output image is
        generated (first batch, e > 0 and e divisible by send_image_every_n).

    network: nn.Gluon.HybridSequential
        The SSD network. Calling it must return
        (default_anchors, class_predictions, box_predictions) and it must
        expose training_targets(default_anchors, class_predictions, label).

    dataloader: gluon.data.DataLoader
        The train or testing dataloader that is wrapped around the iam_dataset.
        It is iterated for (X, Y) batches and len(dataloader) is used to
        average the loss.

    trainer: object with a step(batch_size) method
        Presumably a gluon.Trainer -- TODO confirm against the caller.
        Only used when is_train is true.

    print_name: Str
        Name to print for associating with the data. usually this will be "train" and "test"

    is_train: bool
        Boolean to indicate whether or not the CNN should be updated. is_train should only be set to true for the training data.
        It is also forwarded as train_mode to autograd.record.

    update_metric: bool
        Whether cls_metric and box_metric are updated after each batch.

    Returns
    -------
    total_loss: float
        Sum of the per-device accumulated mean losses, each divided by
        (number of batches * number of devices).
    '''
    # One running loss accumulator per device in ctx.
    total_losses = [0] * len(ctx)

    for i, (X, Y) in enumerate(dataloader):
        X = gluon.utils.split_and_load(X, ctx)
        Y = gluon.utils.split_and_load(Y, ctx)

        with autograd.record(train_mode=is_train):
            losses = []
            for x, y in zip(X, Y):
                default_anchors, class_predictions, box_predictions = network(x)
                box_target, box_mask, cls_target = network.training_targets(
                    default_anchors, class_predictions, y)
                # Combined classification + box regression loss for this shard.
                loss_class = cls_loss(class_predictions, cls_target)
                loss_box = box_loss(box_predictions, box_target, box_mask)
                losses.append(loss_class + loss_box)

        if is_train:
            for loss in losses:
                loss.backward()
            # The step size is the total number of samples across all shards.
            # A generator is used so that `x` (needed below for the output
            # image) is not rebound by this computation.
            trainer.step(sum(shard.shape[0] for shard in X))

        for index, loss in enumerate(losses):
            total_losses[index] += loss.mean().asscalar()

        # NOTE(review): the tensors used below are whatever the last iteration
        # of the per-device loop left behind, so with several devices only the
        # last shard contributes to the metrics and to the output image.
        if update_metric:
            cls_metric.update([cls_target], [nd.transpose(class_predictions, (0, 2, 1))])
            box_metric.update([box_target], [box_predictions * box_mask])

        if i == 0 and e % send_image_every_n == 0 and e > 0:
            cls_probs = nd.SoftmaxActivation(nd.transpose(class_predictions, (0, 2, 1)), mode='channel')
            # The rendered image itself is not used here; only the box count is reported.
            _, number_of_bbs = generate_output_image(box_predictions, default_anchors,
                                                     cls_probs, box_target, box_mask,
                                                     cls_target, x, y)
            print("Number of predicted {} BBs = {}".format(print_name, number_of_bbs))

    # Average over batches and devices; each term is divided before summing to
    # keep the same floating point evaluation order as a running accumulation.
    normaliser = len(dataloader) * len(total_losses)
    total_loss = sum(device_loss / normaliser for device_loss in total_losses)

    return total_loss