in train.py [0:0]
def get_its(hps):
    """Derive iteration counts for training, testing and full validation.

    Reads ``n_train``, ``n_test``, ``n_batch_train`` and ``local_batch_test``
    from ``hps`` and scales the per-worker batch sizes by ``hvd.size()``.

    Returns:
        A ``(train_its, test_its, full_test_its)`` tuple:
        ``train_its`` and ``test_its`` are ``n_train`` / ``n_test`` divided by
        the global training batch, rounded up; ``full_test_its`` is ``n_test``
        divided exactly by the global test batch.

    Raises:
        AssertionError: if ``n_test`` is not a multiple of
            ``local_batch_test * hvd.size()``. NOTE(review): this check is an
            ``assert``, so it is skipped when Python runs with ``-O``.

    Side effects:
        When ``hvd.rank() == 0``, prints the test-set sizing values and the
        train epoch size.
    """
    world_size = hvd.size()

    # Number of examples consumed by one step across all workers.
    global_train_batch = hps.n_batch_train * world_size

    # These run for a fixed amount of time. As anchored batch is smaller,
    # we've actually seen fewer examples.
    train_its = int(np.ceil(hps.n_train / global_train_batch))
    # NOTE(review): the test count is derived from the *training* batch size,
    # exactly as before — confirm this is intended.
    test_its = int(np.ceil(hps.n_test / global_train_batch))
    train_epoch = train_its * global_train_batch

    # Do a full validation run: the test set must split evenly over workers.
    on_root = hvd.rank() == 0
    if on_root:
        print(hps.n_test, hps.local_batch_test, world_size)
    global_test_batch = hps.local_batch_test * world_size
    assert hps.n_test % global_test_batch == 0
    full_test_its = hps.n_test // global_test_batch

    if on_root:
        print("Train epoch size: " + str(train_epoch))
    return train_its, test_its, full_test_its