def run_epoch()

in sagemaker/src/handwriting_line_recognition.py


def run_epoch(e, network, dataloader, trainer, log_dir, print_name, is_train):
    '''
    Run one epoch to train or test the CNN-biLSTM network
    
    Parameters
    ----------
    e: int
        The epoch number
    network: gluon.nn.HybridSequential
        The CNN-biLSTM network
    dataloader: gluon.data.DataLoader
        The training or testing dataloader wrapped around the iam_dataset
    trainer: gluon.Trainer
        Trainer that updates the network parameters after each batch
    log_dir: str
        The directory to store the log files for mxboard
    print_name: str
        Name printed alongside the data; usually "train" or "test"
    is_train: bool
        Whether the network parameters should be updated. is_train should only be
        set to True for the training dataloader.

    Returns
    -------
    epoch_loss: float
        The average loss of the current epoch
    '''

    total_loss = [nd.zeros(1, ctx_) for ctx_ in ctx]
    for i, (x_, y_) in enumerate(dataloader):
        # Split each batch evenly across the available contexts (GPUs or CPUs)
        X = gluon.utils.split_and_load(x_, ctx)
        Y = gluon.utils.split_and_load(y_, ctx)
        # Forward pass and CTC loss on each context; train_mode controls whether
        # layers such as dropout behave in training or inference mode
        with autograd.record(train_mode=is_train):
            output = [network(x) for x in X]
            loss_ctc = [ctc_loss(o, y) for o, y in zip(output, Y)]

        if is_train:
            # Backpropagate the CTC loss on each context, then update the
            # parameters once per batch, normalizing by the global batch size
            for l in loss_ctc:
                l.backward()
            trainer.step(x_.shape[0])

        if i == 0 and e % send_image_every_n == 0 and e > 0:
            # Periodically decode the first image of the first batch so the
            # predicted text can be inspected
            predictions = output[0][:1].softmax().topk(axis=2).asnumpy()
            decoded_text = decode(predictions)
            image = X[0][:1].asnumpy()
            # Undo the input normalization (scale by the dataset std and add the
            # mean) before drawing the decoded text onto the image
            image = image * 0.15926149044640417 + 0.942532484060557
            output_image = draw_text_on_image(image, decoded_text)
            print("{} first decoded text = {}".format(print_name, decoded_text[0]))

        for ctx_i, l in enumerate(loss_ctc):
            # Accumulate the mean loss per context; use a separate index so the
            # batch index i from the outer loop is not shadowed
            total_loss[ctx_i] += l.mean()

    # Average the accumulated loss over all batches and all contexts
    epoch_loss = float(sum([tl.asscalar() for tl in total_loss])) / (len(dataloader) * len(ctx))

    return epoch_loss
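
run_epoch relies on several module-level objects defined elsewhere in handwriting_line_recognition.py (ctx, ctc_loss, send_image_every_n, decode, draw_text_on_image). The sketch below shows how the function is typically driven from a training loop; the context selection, the CTCLoss weight, the send_image_every_n interval, the optimizer settings, and the names epochs, train_data, test_data and log_dir are illustrative assumptions rather than the values used in the source file.

import mxnet as mx
from mxnet import gluon

# Train on all available GPUs, falling back to the CPU (assumed setup)
ctx = [mx.gpu(i) for i in range(mx.context.num_gpus())] or [mx.cpu()]

# CTC loss applied to the network output (weight value is a placeholder)
ctc_loss = gluon.loss.CTCLoss(weight=0.2)

# Interval (in epochs) at which a decoded sample image is printed/logged
send_image_every_n = 20

# network, train_data, test_data and log_dir are assumed to be constructed
# elsewhere in the file, as in the rest of handwriting_line_recognition.py
trainer = gluon.Trainer(network.collect_params(), "adam", {"learning_rate": 0.0001})

epochs = 120  # placeholder value
for e in range(epochs):
    train_loss = run_epoch(e, network, train_data, trainer, log_dir,
                           print_name="train", is_train=True)
    test_loss = run_epoch(e, network, test_data, trainer, log_dir,
                          print_name="test", is_train=False)
    print("Epoch {}: train_loss = {:.4f}, test_loss = {:.4f}".format(e, train_loss, test_loss))

Passing is_train=False for the test dataloader runs the same forward pass and loss computation but skips the backward pass and parameter update, so the returned epoch_loss can be compared directly between the two calls.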