# research/gam/gam/trainer/trainer_agreement.py
def train(self, data, session=None, **kwargs):
  """Train an agreement model.

  Args:
    data: Dataset providing the labeled samples and, if `self.use_graph` is
      set, the graph edges used to create agreement training pairs.
    session: TensorFlow session used to run the training and evaluation ops.
    **kwargs: Must contain `summary_writer`, a summary writer supporting
      `add_summary` and `flush`.

  Returns:
    The best validation accuracy achieved, or None if there are no training
    samples.
  """
  summary_writer = kwargs['summary_writer']
  logging.info('Training agreement model...')
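  # On the first call, just mark the trainer as initialized. On later calls,
  # apply the weight decay update op, if one was configured, before
  # retraining.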
  if not self.is_initialized:
    self.is_initialized = True
  else:
    if self.weight_decay_update is not None:
      session.run(self.weight_decay_update)
      logging.info('New weight decay value: %f',
                   session.run(self.weight_decay_var))
  # Construct data iterator.
  if self.use_graph:
    edges_train, agreement_train, edges_val, agreement_val = \
        self._get_neighbors(data)
    num_samples_train = agreement_train.shape[0]
    num_samples_val = agreement_val.shape[0]
  else:
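    # Without a graph, every ordered pair of labeled samples is a candidate
    # training pair, and a fraction `ratio_val` of them (capped at
    # `max_num_samples_val`) is used for validation.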
    labeled_samples = data.get_indices_train()
    num_labeled_samples = len(labeled_samples)
    num_samples_train = num_labeled_samples * num_labeled_samples
    num_samples_val = min(
        int(num_samples_train * self.ratio_val), self.max_num_samples_val)
  if num_samples_train == 0:
    logging.info('No samples to train agreement. Skipping...')
    return None
  if not self.warm_start:
    # Re-initialize variables.
    initializers = [v.initializer for v in self.trainable_vars]
    initializers.append(self.global_step.initializer)
    session.run(initializers)
    # Reset the optimizer state (e.g., momentum).
    session.run(self.reset_optimizer)
  logging.info(
      'Training agreement with %d samples and validation on %d samples.',
      num_samples_train, num_samples_val)
  # Create an iterator over training data pairs.
  if self.use_graph:
    # If we use the graph, then the training data consists of graph edges
    # and the agreement (1.0 or 0.0) between them.
    data_iterator_train = self._get_train_edge_iterator(
        edges_train,
        agreement_train,
        self.batch_size,
        data,
        add_negatives=self.add_negative_edges)
  else:
    # If we don't use the graph, then the training data consists of pairs of
    # labeled samples and the agreement (1.0 or 0.0) between them.
    # Compute the ratio of positive to negative samples.
    labeled_samples_labels = data.get_labels(labeled_samples)
    ratio_pos_to_neg = self._compute_ratio_pos_neg(labeled_samples_labels)
    # Split data into train and validation.
    labeled_samples_train, labeled_nodes_val = self._select_val_samples(
        labeled_samples, self.ratio_val)
    # Create an iterator over training data pairs.
    data_iterator_train = self._pair_iterator(
        labeled_samples_train, data, ratio_pos_neg=ratio_pos_to_neg)
  # Start training.
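  # `best_val_acc` tracks the best validation accuracy seen so far, so that
  # the corresponding checkpoint can be restored at the end of training, and
  # `iter_below_tol` counts consecutive iterations with little change in the
  # loss (see `check_convergence`).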
  best_val_acc = -1
  checkpoint_saved = False
  step = 0
  iter_below_tol = 0
  min_num_iter = self.min_num_iter
  has_converged = step >= self.max_num_iter
  if not has_converged:
    self.num_iter_trained += 1
  prev_loss_val = np.inf
  while not has_converged:
    feed_dict = self._construct_feed_dict(data_iterator_train, is_train=True)
    if self.enable_summaries and step % self.summary_step == 0:
      loss_val, summary, iter_total, _ = session.run(
          [self.loss_op, self.summary_op, self.iter_agr_total, self.train_op],
          feed_dict=feed_dict)
      summary_writer.add_summary(summary, iter_total)
      summary_writer.flush()
    else:
      loss_val, _ = session.run((self.loss_op, self.train_op),
                                feed_dict=feed_dict)
    # Log the loss, if necessary.
    if step % self.logging_step == 0:
      logging.info('Agreement step %6d | Loss: %10.4f', step, loss_val)
    # Run validation, if necessary.
    if step % self.eval_step == 0:
      if num_samples_val == 0:
        logging.info('Skipping validation. No validation samples available.')
        break
      # Evaluate on the selected validation data.
      if self.use_graph:
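        # Validation edges are visited once each, in a fixed order (no
        # shuffling, no repetition), allowing a final smaller batch.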
        data_iterator_val = batch_iterator(
            edges_val,
            agreement_val,
            batch_size=self.batch_size,
            shuffle=False,
            allow_smaller_batch=True,
            repeat=False)
      else:
        data_iterator_val = self._pair_iterator(
            labeled_nodes_val, data, ratio_pos_neg=ratio_pos_to_neg)
      val_acc = self._eval_validation(data_iterator_val, num_samples_val,
                                      session)
      # Evaluate over a random choice of sample pairs, either labeled or not.
      acc_random = self._eval_random_pairs(data, session)
      # Evaluate the accuracy on the latest train batch. We track this to make
      # sure the agreement model is able to fit the training data, but it can
      # be eliminated if efficiency is an issue.
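      # `acc_0_train` and `acc_1_train` are the training-batch accuracies on
      # pairs with agreement 0 and agreement 1, respectively.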
      acc_train, acc_0_train, acc_1_train = self._eval_train(
          session, feed_dict)
      if self.enable_summaries:
        summary = tf.Summary()
        summary.value.add(
            tag='AgreementModel/train_acc', simple_value=acc_train)
        summary.value.add(tag='AgreementModel/val_acc', simple_value=val_acc)
        if acc_random is not None:
          summary.value.add(
              tag='AgreementModel/random_acc', simple_value=acc_random)
        iter_total = session.run(self.iter_agr_total)
        summary_writer.add_summary(summary, iter_total)
        summary_writer.flush()
      if step % self.logging_step == 0 or val_acc > best_val_acc:
        logging.info(
            'Agreement step %6d | Loss: %10.4f | val_acc: %.4f | '
            'random_acc: %.4f | acc_train: %.4f | acc_train_cls_0: %.4f | '
            'acc_train_cls_1: %.4f', step, loss_val, val_acc, acc_random,
            acc_train, acc_0_train, acc_1_train)
      if val_acc > best_val_acc:
        best_val_acc = val_acc
        if self.checkpoint_path:
          self.saver.save(
              session, self.checkpoint_path, write_meta_graph=False)
          checkpoint_saved = True
        # If we reached 100% accuracy, stop.
        if best_val_acc >= 1.00:
          logging.info('Reached 100% accuracy. Stopping...')
          break
        # Go for at least num_iter_after_best_val more iterations.
        min_num_iter = max(self.min_num_iter,
                           step + self.num_iter_after_best_val)
        logging.info(
            'Achieved best validation. '
            'Extending to at least %d iterations...', min_num_iter)
    step += 1
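    # Stop when the loss change stays below the convergence tolerance for
    # enough consecutive iterations or `max_num_iter` is reached, but never
    # before `min_num_iter` iterations.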
    has_converged, iter_below_tol = self.check_convergence(
        prev_loss_val,
        loss_val,
        step,
        self.max_num_iter,
        iter_below_tol,
        min_num_iter=min_num_iter)
    session.run(self.iter_agr_total_update)
    prev_loss_val = loss_val
  # Return to the best model.
  if checkpoint_saved:
    logging.info('Restoring best model...')
    self.saver.restore(session, self.checkpoint_path)
  return best_val_acc
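
# Illustrative sketch (not part of the original file): a minimal TF1-style
# driver for this method. The trainer construction is elided because its
# arguments are not shown in this excerpt; `trainer` stands for an already
# constructed agreement trainer and `data` for its dataset object.
#
#   with tf.Session() as session:
#     session.run(tf.global_variables_initializer())
#     summary_writer = tf.summary.FileWriter('/tmp/agreement', session.graph)
#     best_val_acc = trainer.train(
#         data, session=session, summary_writer=summary_writer)
#     if best_val_acc is not None:
#       logging.info('Best agreement validation accuracy: %.4f', best_val_acc)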