in research/gam/gam/trainer/trainer_classification.py [0:0]
def train(self, data, session=None, **kwargs):
    """Train the classification model on the provided dataset.

    On every call after the first, the weight decay update op (if any) is
    run and the optimizer state is reset. Unless `self.warm_start` is set,
    the model variables and the global step are re-initialized. Training
    steps then run until `self.check_convergence` reports convergence. Every
    `self.eval_step` steps the model is evaluated on the validation and test
    splits. When `self.checkpoint_path` is set, the weights with the best
    validation accuracy are saved and restored before returning.

    Args:
      data: A CotrainDataset object.
      session: A TensorFlow session. Although the signature defaults to None,
        a live session is needed: it is used to run every training step.
      **kwargs: Other keyword arguments. `summary_writer` is used to write
        summaries and is required when `self.enable_summaries` is True. When
        omitted, None is forwarded to `self._evaluate`.
    Returns:
      best_test_acc: A float representing the test accuracy at the iteration
        where the validation accuracy is maximum (-1 if no evaluation ran).
      best_val_acc: A float representing the best validation accuracy
        (-1 if no evaluation ran).
    Raises:
      ValueError: If summaries are enabled but no `summary_writer` is given.
    """
    summary_writer = kwargs.get('summary_writer')
    # Fail before touching any state, instead of crashing mid-training at the
    # first summary write.
    if self.enable_summaries and summary_writer is None:
        raise ValueError('A `summary_writer` keyword argument is required '
                         'when summaries are enabled.')
    logging.info('Training classifier...')

    if not self.is_initialized:
        self.is_initialized = True
    else:
        # Retraining: apply the weight decay update op, if one is configured.
        if self.weight_decay_update is not None:
            session.run(self.weight_decay_update)
            logging.info('New weight decay value: %f',
                         session.run(self.weight_decay_var))
        # Reset the optimizer state (e.g., momentum).
        session.run(self.reset_optimizer)

    if not self.warm_start:
        # Re-initialize variables.
        initializers = [v.initializer for v in self.variables.values()]
        initializers.append(self.global_step.initializer)
        session.run(initializers)

    # Construct data iterator.
    logging.info('Training classifier with %d samples...', data.num_train())
    train_indices = data.get_indices_train()
    unlabeled_indices = data.get_indices_unlabeled()
    val_indices = data.get_indices_val()
    test_indices = data.get_indices_test()
    # Create an iterator for labeled samples for the supervised term.
    data_iterator_train = batch_iterator(
        train_indices,
        batch_size=self.batch_size,
        shuffle=True,
        allow_smaller_batch=False,
        repeat=True)
    # Create an iterator for unlabeled samples for the VAT loss term.
    data_iterator_unlabeled = batch_iterator(
        unlabeled_indices,
        batch_size=self.batch_size,
        shuffle=True,
        allow_smaller_batch=False,
        repeat=True)
    # Create iterators for ll, lu, uu pairs of samples for the agreement term.
    if self.use_graph:
        pair_ll_iterator = self.edge_iterator(
            data, batch_size=self.num_pairs_reg, labeling='ll')
        pair_lu_iterator = self.edge_iterator(
            data, batch_size=self.num_pairs_reg, labeling='lu')
        pair_uu_iterator = self.edge_iterator(
            data, batch_size=self.num_pairs_reg, labeling='uu')
    else:
        pair_ll_iterator = self.pair_iterator(train_indices, train_indices,
                                              self.num_pairs_reg, data)
        pair_lu_iterator = self.pair_iterator(train_indices, unlabeled_indices,
                                              self.num_pairs_reg, data)
        pair_uu_iterator = self.pair_iterator(unlabeled_indices,
                                              unlabeled_indices,
                                              self.num_pairs_reg, data)

    step = 0
    iter_below_tol = 0
    min_num_iter = self.min_num_iter
    has_converged = step >= self.max_num_iter
    prev_loss_val = np.inf
    best_test_acc = -1
    best_val_acc = -1
    checkpoint_saved = False
    while not has_converged:
        feed_dict = self._construct_feed_dict(
            data_iterator=data_iterator_train,
            split='train',
            pair_ll_iterator=pair_ll_iterator,
            pair_lu_iterator=pair_lu_iterator,
            pair_uu_iterator=pair_uu_iterator,
            data_iterator_unlabeled=data_iterator_unlabeled)
        if self.enable_summaries and step % self.summary_step == 0:
            loss_val, summary, iter_cls_total, _ = session.run(
                [self.loss_op, self.summary_op, self.iter_cls_total,
                 self.train_op],
                feed_dict=feed_dict)
            summary_writer.add_summary(summary, iter_cls_total)
            summary_writer.flush()
        else:
            loss_val, _ = session.run((self.loss_op, self.train_op),
                                      feed_dict=feed_dict)

        is_logging_step = step % self.logging_step == 0
        is_eval_step = step % self.eval_step == 0
        # Log the loss, unless this is an evaluation step: the evaluation
        # log line below already reports the loss on logging steps, so
        # logging here too would duplicate it.
        if is_logging_step and not is_eval_step:
            logging.info('Classification step %6d | Loss: %10.4f', step,
                         loss_val)
        # Evaluate, if necessary.
        if is_eval_step:
            val_acc = self._evaluate(val_indices, 'val', session,
                                     summary_writer)
            test_acc = self._evaluate(test_indices, 'test', session,
                                      summary_writer)
            if is_logging_step or val_acc > best_val_acc:
                logging.info(
                    'Classification step %6d | Loss: %10.4f | val_acc: %10.4f | '
                    'test_acc: %10.4f', step, loss_val, val_acc, test_acc)
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_test_acc = test_acc
                if self.checkpoint_path:
                    self.saver.save(
                        session, self.checkpoint_path, write_meta_graph=False)
                    checkpoint_saved = True
                # Go for at least num_iter_after_best_val more iterations.
                min_num_iter = max(self.min_num_iter,
                                   step + self.num_iter_after_best_val)
                logging.info(
                    'Achieved best validation. '
                    'Extending to at least %d iterations...', min_num_iter)

        step += 1
        has_converged, iter_below_tol = self.check_convergence(
            prev_loss_val,
            loss_val,
            step,
            self.max_num_iter,
            iter_below_tol,
            min_num_iter=min_num_iter)
        session.run(self.iter_cls_total_update)
        prev_loss_val = loss_val

    # Return to the best model. Without a checkpoint path the session keeps
    # the weights from the last step, not those that achieved best_val_acc.
    if checkpoint_saved:
        logging.info('Restoring best model...')
        self.saver.restore(session, self.checkpoint_path)

    return best_test_acc, best_val_acc