in train.py [0:0]
def train(sess, model, hps, logdir, visualise):
    """Run the training loop with periodic validation, sampling and logging.

    Each epoch runs ``hps.train_its`` steps of ``model.train(lr)`` and averages
    the per-step results over axis 0. On a subset of epochs (1-9, every 10th
    epoch below 50, and every ``hps.epochs_full_valid``-th epoch) it then
    optionally runs a full test pass, checkpoints the best model, calls
    ``visualise`` and prints a one-line summary.

    Args:
        sess: Session whose graph is finalized before the loop
            (presumably a ``tf.Session``).
        model: Object exposing ``train(lr)``, ``test()`` and ``save(path)``.
            Results of ``train``/``test`` are stacked with ``np.asarray`` and
            averaged over axis 0; element 0 of the averaged test result is
            the value minimised for best-model checkpointing.
        hps: Hyper-parameter namespace; its whole ``__dict__`` is forwarded
            to the result loggers.
        logdir: Output prefix. File names are appended by plain string
            concatenation, so it presumably needs a trailing path separator.
        visualise: Callable invoked as ``visualise(epoch)``.

    All ranks run the training/test steps and ``visualise``; only rank 0
    (``hvd.rank() == 0``) writes logs and checkpoints and prints summaries.
    """
    _print(hps)
    _print('Starting training. Logging to', logdir)
    _print('epoch n_processed n_images ips dtrain dtest dsample dtot train_results test_results msg')

    # Train
    sess.graph.finalize()
    is_root = hvd.rank() == 0
    n_processed = 0  # images seen wrt the anchor resolution
    n_images = 0  # images actually seen at the current resolution
    train_time = 0.0
    # +inf so the first full validation always checkpoints, whatever the loss
    # scale (a finite sentinel such as 999999 might never be beaten).
    test_loss_best = float('inf')

    if is_root:
        train_logger = ResultLogger(logdir + "train.txt", **hps.__dict__)
        test_logger = ResultLogger(logdir + "test.txt", **hps.__dict__)

    # Images over which the learning rate is warmed up. It is loop-invariant,
    # so it is computed once; a value of 0 means "no warmup", which the old
    # expression turned into a ZeroDivisionError.
    warmup_images = hps.n_train * hps.epochs_warmup

    tcurr = time.time()
    # Epochs are numbered 1..hps.epochs inclusive, so exactly hps.epochs run.
    for epoch in range(1, hps.epochs + 1):

        t = time.time()

        train_results = []
        for _ in range(hps.train_its):

            # Learning rate, linearly annealed from 0 over the first
            # hps.epochs_warmup epochs.
            if warmup_images > 0:
                lr = hps.lr * min(1., n_processed / warmup_images)
            else:
                lr = hps.lr

            # Run a training step synchronously.
            _t = time.time()
            train_results.append(model.train(lr))
            if hps.verbose and is_root:
                _print(n_processed, time.time() - _t, train_results[-1])
                sys.stdout.flush()

            # Images seen wrt anchor resolution
            n_processed += hvd.size() * hps.n_batch_train
            # Actual images seen at current resolution
            n_images += hvd.size() * hps.local_batch_train

        train_results = np.mean(np.asarray(train_results), axis=0)

        dtrain = time.time() - t
        ips = (hps.train_its * hvd.size() * hps.local_batch_train) / dtrain
        train_time += dtrain

        if is_root:
            train_logger.log(epoch=epoch, n_processed=n_processed, n_images=n_images,
                             train_time=int(train_time), **process_results(train_results))

        if epoch < 10 or (epoch < 50 and epoch % 10 == 0) or epoch % hps.epochs_full_valid == 0:
            test_results = []
            msg = ''

            t = time.time()
            # model.polyak_swap()

            if epoch % hps.epochs_full_valid == 0:
                # Full validation run
                test_results = np.mean(
                    np.asarray([model.test() for _ in range(hps.full_test_its)]), axis=0)

                if is_root:
                    test_logger.log(epoch=epoch, n_processed=n_processed,
                                    n_images=n_images, **process_results(test_results))

                    # Save checkpoint when element 0 (the tracked loss) improves.
                    if test_results[0] < test_loss_best:
                        test_loss_best = test_results[0]
                        model.save(logdir + "model_best_loss.ckpt")
                        msg += ' *'

            dtest = time.time() - t

            # Sample
            # NOTE(review): sampling is nested inside the evaluation schedule
            # above, so a multiple of hps.epochs_full_sample outside that
            # schedule produces no samples -- confirm this is intended.
            t = time.time()
            if epoch == 1 or epoch == 10 or epoch % hps.epochs_full_sample == 0:
                visualise(epoch)
            dsample = time.time() - t

            if is_root:
                dcurr = time.time() - tcurr
                tcurr = time.time()
                _print(epoch, n_processed, n_images, "{:.1f} {:.1f} {:.1f} {:.1f} {:.1f}".format(
                    ips, dtrain, dtest, dsample, dcurr), train_results, test_results, msg)

            # model.polyak_swap()

    if is_root:
        _print("Finished!")