in ocr/paragraph_segmentation_dcnn.py [0:0]
def main(ctx=mx.gpu()):
    """Train the segmentation network and checkpoint the best-scoring weights.

    Builds train and test ``IAMDataset`` loaders, creates a
    ``SegmentationNetwork`` whose parameters are moved to ``ctx``, optionally
    restores a previous checkpoint, and then alternates one training and one
    test pass of ``run_epoch`` per epoch. Parameters are written to
    ``checkpoint_dir/checkpoint_name`` every time the test loss drops below
    the best value seen so far.

    Settings such as ``checkpoint_dir``, ``checkpoint_name``,
    ``restore_checkpoint_name``, ``batch_size``, ``learning_rate``,
    ``epochs``, ``print_every_n``, ``log_dir``, ``loss_function``,
    ``transform`` and ``augment_transform`` are not parameters; they are
    resolved from names defined outside this function.

    Args:
        ctx: Context handed to ``reset_ctx`` and ``load_parameters``.
            Note that the default ``mx.gpu()`` is evaluated once, when the
            function is defined.
    """
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(checkpoint_dir, exist_ok=True)

    train_ds = IAMDataset("form", output_data="bb", output_parse_method="form", train=True)
    print("Number of training samples: {}".format(len(train_ds)))

    test_ds = IAMDataset("form", output_data="bb", output_parse_method="form", train=False)
    print("Number of testing samples: {}".format(len(test_ds)))

    # Both loaders use half of the reported CPUs as worker processes.
    num_workers = multiprocessing.cpu_count() // 2
    train_data = gluon.data.DataLoader(
        train_ds.transform(augment_transform), batch_size,
        shuffle=True, num_workers=num_workers)
    test_data = gluon.data.DataLoader(
        test_ds.transform(transform), batch_size,
        shuffle=False, num_workers=num_workers)

    net = SegmentationNetwork()
    net.hybridize()
    net.collect_params().reset_ctx(ctx)

    if restore_checkpoint_name:
        # Build the path the same way the save step below does.
        net.load_parameters(os.path.join(checkpoint_dir, restore_checkpoint_name), ctx=ctx)

    trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': learning_rate})

    # Start from +inf so the first finite test loss always counts as an
    # improvement, however large it is.
    best_test_loss = float("inf")
    for e in range(epochs):
        train_loss = run_epoch(e, net, train_data, loss_function=loss_function, log_dir=log_dir,
                               trainer=trainer, print_name="train", is_train=True)
        test_loss = run_epoch(e, net, test_data, loss_function=loss_function, log_dir=log_dir,
                              trainer=trainer, print_name="test", is_train=False)

        if test_loss < best_test_loss:
            print("Saving network, previous best test loss {:.6f}, current test loss {:.6f}".format(best_test_loss, test_loss))
            net.save_parameters(os.path.join(checkpoint_dir, checkpoint_name))
            best_test_loss = test_loss

        # Periodic progress report; epoch 0 is skipped.
        if e % print_every_n == 0 and e > 0:
            print("Epoch {0}, train_loss {1:.6f}, test_loss {2:.6f}".format(e, train_loss, test_loss))