def main()

in ocr/paragraph_segmentation_dcnn.py [0:0]
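Training entry point for the paragraph-segmentation network: it loads the IAM forms dataset with paragraph bounding-box targets, trains a SegmentationNetwork with Adam, and checkpoints the parameters whenever the test loss improves.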


def main(ctx=mx.gpu()):
    # Make sure the checkpoint directory exists before anything is saved into it.
    if not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # Load the IAM dataset at the form level; the target output ("bb") is the
    # bounding box of the handwritten paragraph on each form.
    train_ds = IAMDataset("form", output_data="bb", output_parse_method="form", train=True)
    print("Number of training samples: {}".format(len(train_ds)))

    test_ds = IAMDataset("form", output_data="bb", output_parse_method="form", train=False)
    print("Number of testing samples: {}".format(len(test_ds)))

    # Wrap the datasets in DataLoaders: the training set is augmented and shuffled,
    # the test set only gets the base transform.
    train_data = gluon.data.DataLoader(train_ds.transform(augment_transform), batch_size,
                                       shuffle=True, num_workers=int(multiprocessing.cpu_count()/2))
    test_data = gluon.data.DataLoader(test_ds.transform(transform), batch_size,
                                      shuffle=False, num_workers=int(multiprocessing.cpu_count()/2))

    # Build the segmentation network, hybridize it for graph-level optimization,
    # and move its parameters to the target context.
    net = SegmentationNetwork()
    net.hybridize()
    net.collect_params().reset_ctx(ctx)
    # Optionally resume from a previously saved checkpoint.
    if restore_checkpoint_name:
        net.load_parameters("{}/{}".format(checkpoint_dir, restore_checkpoint_name), ctx=ctx)

    # Optimize with Adam and track the best (lowest) test loss seen so far.
    trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': learning_rate})
    best_test_loss = 10e5
    for e in range(epochs):
        train_loss = run_epoch(e, net, train_data, loss_function=loss_function, log_dir=log_dir, 
                               trainer=trainer, print_name="train", is_train=True)
        test_loss = run_epoch(e, net, test_data, loss_function=loss_function, log_dir=log_dir,
                              trainer=trainer, print_name="test", is_train=False)
        # Checkpoint the network whenever the test loss improves.
        if test_loss < best_test_loss:
            print("Saving network, previous best test loss {:.6f}, current test loss {:.6f}".format(best_test_loss, test_loss))
            net.save_parameters(os.path.join(checkpoint_dir, checkpoint_name))
            best_test_loss = test_loss
        # Periodically report the training and test losses.
        if e % print_every_n == 0 and e > 0:
            print("Epoch {0}, train_loss {1:.6f}, test_loss {2:.6f}".format(e, train_loss, test_loss))