def train()

in research/gam/gam/trainer/trainer_agreement.py


  def train(self, data, session=None, **kwargs):
    """Trains an agreement model.

    Args:
      data: Dataset providing the labeled samples and, if `use_graph` is set,
        the graph edges.
      session: TensorFlow session in which to run the training ops.
      **kwargs: Must include `summary_writer` for TensorBoard summaries.

    Returns:
      Best validation accuracy achieved, or None if there is no training data.
    """

    summary_writer = kwargs['summary_writer']
    logging.info('Training agreement model...')

    # On the first call the model is freshly initialized; on later calls,
    # anneal the weight decay if an update op was provided.
    if not self.is_initialized:
      self.is_initialized = True
    else:
      if self.weight_decay_update is not None:
        session.run(self.weight_decay_update)
        logging.info('New weight decay value: %f',
                     session.run(self.weight_decay_var))

    # Construct data iterator.
    if self.use_graph:
      (edges_train, agreement_train, edges_val,
       agreement_val) = self._get_neighbors(data)
      num_samples_train = agreement_train.shape[0]
      num_samples_val = agreement_val.shape[0]
    else:
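      # Without the graph, we train on all ordered pairs of labeled samples.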
      labeled_samples = data.get_indices_train()
      num_labeled_samples = len(labeled_samples)
      num_samples_train = num_labeled_samples * num_labeled_samples
      num_samples_val = min(
          int(num_samples_train * self.ratio_val), self.max_num_samples_val)

    if num_samples_train == 0:
      logging.info('No samples to train agreement. Skipping...')
      return None

    if not self.warm_start:
      # Re-initialize variables.
      initializers = [v.initializer for v in self.trainable_vars]
      initializers.append(self.global_step.initializer)
      session.run(initializers)
      # Reset the optimizer state (e.g., momentum).
      session.run(self.reset_optimizer)

    logging.info(
        'Training agreement with %d samples and validation on %d samples.',
        num_samples_train, num_samples_val)

    # Create an iterator over training data pairs.
    if self.use_graph:
      # If we use the graph, then the training data consists of graph edges
      # and the agreement (1.0 or 0.0) between them.
      data_iterator_train = self._get_train_edge_iterator(
          edges_train,
          agreement_train,
          self.batch_size,
          data,
          add_negatives=self.add_negative_edges)
    else:
      # If we don't use the graph, then the training data consists of pairs of
      # labeled samples, and the agreement (1.0 or 0.0) between them.

      # Compute ratio of positives to negative samples.
      labeled_samples_labels = data.get_labels(labeled_samples)
      ratio_pos_to_neg = self._compute_ratio_pos_neg(labeled_samples_labels)

      # Split data into train and validation.
      labeled_samples_train, labeled_samples_val = self._select_val_samples(
          labeled_samples, self.ratio_val)

      # Create an iterator over training data pairs.
      data_iterator_train = self._pair_iterator(
          labeled_samples_train, data, ratio_pos_neg=ratio_pos_to_neg)

    # Start training.
    best_val_acc = -1
    checkpoint_saved = False
    step = 0
    iter_below_tol = 0
    min_num_iter = self.min_num_iter
    has_converged = step >= self.max_num_iter
    if not has_converged:
      self.num_iter_trained += 1
    prev_loss_val = np.inf
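    # Main loop: one optimization step per iteration, with periodic logging,
    # summaries, and validation-based checkpointing of the best model.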
    while not has_converged:
      feed_dict = self._construct_feed_dict(data_iterator_train, is_train=True)

      if self.enable_summaries and step % self.summary_step == 0:
        loss_val, summary, iter_total, _ = session.run(
            [self.loss_op, self.summary_op, self.iter_agr_total, self.train_op],
            feed_dict=feed_dict)
        summary_writer.add_summary(summary, iter_total)
        summary_writer.flush()
      else:
        loss_val, _ = session.run((self.loss_op, self.train_op),
                                  feed_dict=feed_dict)

      # Log the loss, if necessary.
      if step % self.logging_step == 0:
        logging.info('Agreement step %6d | Loss: %10.4f', step, loss_val)

      # Run validation, if necessary.
      if step % self.eval_step == 0:
        if num_samples_val == 0:
          logging.info('Skipping validation. No validation samples available.')
          break

        # Evaluate on the selected validation data.
        if self.use_graph:
          data_iterator_val = batch_iterator(
              edges_val,
              agreement_val,
              batch_size=self.batch_size,
              shuffle=False,
              allow_smaller_batch=True,
              repeat=False)
        else:
          data_iterator_val = self._pair_iterator(
              labeled_samples_val, data, ratio_pos_neg=ratio_pos_to_neg)
        val_acc = self._eval_validation(data_iterator_val, num_samples_val,
                                        session)

        # Evaluate over a random choice of sample pairs, either labeled or not.
        acc_random = self._eval_random_pairs(data, session)

        # Evaluate the accuracy on the latest train batch. We track this to
        # make sure the agreement model is able to fit the training data; the
        # check can be removed if efficiency is a concern.
        acc_train, acc_0_train, acc_1_train = self._eval_train(
            session, feed_dict)

        if self.enable_summaries:
          summary = tf.Summary()
          summary.value.add(
              tag='AgreementModel/train_acc', simple_value=acc_train)
          summary.value.add(tag='AgreementModel/val_acc', simple_value=val_acc)
          if acc_random is not None:
            summary.value.add(
                tag='AgreementModel/random_acc', simple_value=acc_random)
          iter_total = session.run(self.iter_agr_total)
          summary_writer.add_summary(summary, iter_total)
          summary_writer.flush()
        if step % self.logging_step == 0 or val_acc > best_val_acc:
          logging.info(
              'Agreement step %6d | Loss: %10.4f | val_acc: %.4f | '
              'random_acc: %.4f | acc_train: %.4f | acc_train_cls_0: %.4f | '
              'acc_train_cls_1: %.4f', step, loss_val, val_acc,
              acc_random if acc_random is not None else float('nan'),
              acc_train, acc_0_train, acc_1_train)
        if val_acc > best_val_acc:
          best_val_acc = val_acc
          if self.checkpoint_path:
            self.saver.save(
                session, self.checkpoint_path, write_meta_graph=False)
            checkpoint_saved = True
          # If we reached 100% accuracy, stop.
          if best_val_acc >= 1.00:
            logging.info('Reached 100% accuracy. Stopping...')
            break
          # Go for at least num_iter_after_best_val more iterations.
          min_num_iter = max(self.min_num_iter,
                             step + self.num_iter_after_best_val)
          logging.info(
              'Achieved best validation. '
              'Extending to at least %d iterations...', min_num_iter)

      step += 1
      has_converged, iter_below_tol = self.check_convergence(
          prev_loss_val,
          loss_val,
          step,
          self.max_num_iter,
          iter_below_tol,
          min_num_iter=min_num_iter)
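      # Advance iter_agr_total, the cross-call step counter used as the
      # summary step.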
      session.run(self.iter_agr_total_update)
      prev_loss_val = loss_val

    # Return to the best model.
    if checkpoint_saved:
      logging.info('Restoring best model...')
      self.saver.restore(session, self.checkpoint_path)

    return best_val_acc
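
For reference, `_compute_ratio_pos_neg` balances pair sampling by measuring how many candidate pairs agree. The helper is defined elsewhere on the trainer class; below is a minimal sketch of the computation it implies, counting all ordered pairs of labeled samples (the inclusion of self-pairs is an assumption made for simplicity):

  import numpy as np

  def compute_ratio_pos_neg(labels):
    """Ratio of same-label (agreement 1.0) pairs to different-label pairs."""
    labels = np.asarray(labels)
    _, counts = np.unique(labels, return_counts=True)
    num_pairs = labels.shape[0] ** 2   # All ordered pairs, incl. self-pairs.
    num_pos = np.sum(counts ** 2)      # Pairs whose labels match.
    num_neg = num_pairs - num_pos
    # With a single class there are no negatives; return inf as a sentinel.
    return num_pos / num_neg if num_neg > 0 else np.inf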
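Similarly, `_pair_iterator` is expected to yield endless batches of labeled-sample pairs together with their agreement targets, with `ratio_pos_neg` presumably used to re-balance the two classes. A simplified sketch that samples uniform random pairs and omits the re-balancing (the function name, uniform sampling, and `labels` layout are assumptions):

  import numpy as np

  def pair_iterator(indices, labels, batch_size, rng=None):
    """Yields (src, dst, agreement) batches over random labeled pairs.

    `labels` is assumed to be an array aligned with dataset indices.
    """
    rng = rng or np.random.default_rng()
    indices = np.asarray(indices)
    while True:
      src = rng.choice(indices, size=batch_size)
      dst = rng.choice(indices, size=batch_size)
      # Agreement is 1.0 when the two samples share a label, else 0.0.
      agreement = (labels[src] == labels[dst]).astype(np.float32)
      yield src, dst, agreement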
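Finally, the stopping rule is delegated to `check_convergence`, which returns the updated `(has_converged, iter_below_tol)` pair. A sketch of a plausible rule matching the call site above, where the tolerance and patience parameters (`loss_tol`, `max_iter_below_tol`) are illustrative assumptions:

  def check_convergence(prev_loss, loss, step, max_num_iter, iter_below_tol,
                        min_num_iter=0, loss_tol=1e-4, max_iter_below_tol=10):
    """Loss-based stopping rule: returns (has_converged, iter_below_tol)."""
    if step >= max_num_iter:
      return True, iter_below_tol
    # Count consecutive steps whose loss change stays below the tolerance.
    if abs(prev_loss - loss) < loss_tol:
      iter_below_tol += 1
    else:
      iter_below_tol = 0
    # Do not stop before min_num_iter, which train() extends whenever a new
    # best validation accuracy is found.
    has_converged = (step >= min_num_iter and
                     iter_below_tol >= max_iter_below_tol)
    return has_converged, iter_below_tol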