in tensorflow_fold/blocks/plan.py
def _run(self, supervisor, session):
  train_feed_dict = self.train_feeds.copy()
  train_fetches = {'train_op': self.train_op, 'loss': self.loss_total,
                   'step': self.global_step}
  if self.compute_summaries: train_fetches['summaries'] = self.summaries
  # The training loop is essentially the same regardless of whether
  # we are passing batches by feed dict or by loom input
  # tensor. There are a few minor differences:
  #
  # 1. By feed dict, we compute the size of the training set lazily,
  #    as we iterate over it in the first epoch. By input tensor, we
  #    calculate train_size as batch_size * batches_per_epoch.
  #
  # 2. By feed dict, we get the size of each batch by calling len()
  #    on it (since the last batch in the epoch may have fewer than
  #    batch_size elements). By input tensor, we require that every
  #    batch have exactly batch_size elements.
  #
  # 3. By feed dict, we need to create batches of inputs, and feed
  #    them every time we run the train op (obviously).
  if self.examples:
    epochs, train_size = self._by_feed_dict(train_feed_dict)
  else:
    epochs, train_size = self._by_input_tensor(train_feed_dict)
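  # In both modes `epochs` yields one iterable of batches per training
  # epoch, and `train_size` counts the training examples (lazily in
  # feed-dict mode; see note 1 above).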
  if self.dev_examples:
    # Memoize a generator of batches of (size, feed_dict) pairs.
    gen_dev_batches = util.epochs(
        ((len(batch), self.compiler.build_feed_dict(batch))
         for batch in util.group_by_batches(
             self.dev_examples, self.batch_size)), shuffle=False)
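    # Each next(gen_dev_batches) below yields one full pass over the same
    # memoized sequence of dev batches.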
    # If there is an existing checkpoint in logdir, and we are
    # saving the best model, calculate best_loss before doing any
    # training, so we don't potentially replace a better-performing
    # checkpoint with a worse one.
    ckpt = tf.train.get_checkpoint_state(self.logdir)
    if ckpt and ckpt.model_checkpoint_path:
      _, self._best_loss, _ = self._eval_batches(
          supervisor, session, next(gen_dev_batches), None, is_dev=True)
      if self._best_loss is None: return  # should_stop returned true
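  # Main loop: run the train op on every batch of the epoch, accumulating
  # the summed batch losses so they can be averaged below.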
  for epoch, batches in enumerate(epochs, 1):
    train_loss = 0.0
    for _ in batches:
      if self._should_stop(supervisor): return
      results = session.run(train_fetches, train_feed_dict)
      train_loss += results['loss']
      if self.compute_summaries:
        supervisor.summary_computed(
            session, results['summaries'], results['step'])
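    # In feed-dict mode train_size is only known once the first epoch has
    # been iterated, so the emptiness check has to live inside the loop.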
    if train_size == 0:
      raise ValueError('examples must be non-empty')
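    # With exact_batch_sizes, a trailing partial batch is dropped from
    # training, so exclude those examples from the loss normalizer too;
    # adjusting once after the first epoch is enough.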
    if self.exact_batch_sizes and epoch == 1:
      if train_size < self.batch_size:
        raise ValueError('when exact_batch_sizes is true, examples must have '
                         'at least batch_size items; %s vs. %s' % (
                             train_size, self.batch_size))
      train_size -= train_size % self.batch_size
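    # train_loss now holds the per-batch loss_total values summed over the
    # epoch; dividing by the example count gives the mean per-example loss.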
    train_loss /= train_size
    self.report_loss(results['step'], train_loss)
    log_str = 'epoch:%5d train[loss: %.3e]' % (epoch, train_loss)
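    # With a dev set, evaluate after every epoch and let _save_best
    # checkpoint the model whenever the dev loss improves; otherwise just
    # log the training loss.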
    if self.dev_examples:
      dev_size, dev_loss, dev_metrics = self._eval_batches(
          supervisor, session, next(gen_dev_batches), results['step'],
          is_dev=True)
      if dev_size is None: return  # should_stop returned true
      if epoch == 1: self.log_and_print('train_size: %d dev_size: %d' %
                                        (train_size, dev_size))
      log_str += ' dev[%s]' % _eval_str(dev_size, dev_loss, dev_metrics)
      self.log_and_print(log_str)
      self._save_best(session, supervisor.saver, dev_loss, results['step'])
    else:
      if epoch == 1: self.log_and_print('train_size: %d' % train_size)
      self.log_and_print(log_str)
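  # Without a dev set there is no per-epoch "best" checkpoint, so the chief
  # trainer saves the final model once training ends.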
  if not self.dev_examples and self.is_chief_trainer:
    save_path = os.path.join(self.logdir, 'model.ckpt')
    save_fname = supervisor.saver.save(
        session, save_path, global_step=results['step'])
    self.log_and_print('final model saved in file: %s' % save_fname)