def train()

in research/gam/gam/trainer/trainer_classification.py

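Runs the main optimization loop for the classification model: it trains on labeled batches (plus VAT and agreement regularizers on unlabeled samples and sample pairs), evaluates on the validation and test splits every eval_step steps, and early-stops based on validation accuracy.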

  def train(self, data, session=None, **kwargs):
    """Train the classification model on the provided dataset.

    Args:
      data: A CotrainDataset object.
      session: A TensorFlow session or None.
      **kwargs: Other keyword arguments. Must contain a summary_writer used to
        write TensorBoard summaries during training and evaluation.

    Returns:
      best_test_acc: A float representing the test accuracy at the iteration
        where the validation accuracy is maximum.
      best_val_acc: A float representing the best validation accuracy.
    """
    summary_writer = kwargs['summary_writer']
    logging.info('Training classifier...')

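    # On the first call, just mark the trainer as initialized; on subsequent
    # calls (e.g., later co-training iterations), update the weight decay
    # value and clear the optimizer state before training again.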
    if not self.is_initialized:
      self.is_initialized = True
    else:
      if self.weight_decay_update is not None:
        session.run(self.weight_decay_update)
        logging.info('New weight decay value:  %f',
                     session.run(self.weight_decay_var))
      # Reset the optimizer state (e.g., momentum).
      session.run(self.reset_optimizer)

    if not self.warm_start:
      # Re-initialize variables.
      initializers = [v.initializer for v in self.variables.values()]
      initializers.append(self.global_step.initializer)
      session.run(initializers)

    # Construct data iterator.
    logging.info('Training classifier with %d samples...', data.num_train())
    train_indices = data.get_indices_train()
    unlabeled_indices = data.get_indices_unlabeled()
    val_indices = data.get_indices_val()
    test_indices = data.get_indices_test()
    # Create an iterator for labeled samples for the supervised term.
    data_iterator_train = batch_iterator(
        train_indices,
        batch_size=self.batch_size,
        shuffle=True,
        allow_smaller_batch=False,
        repeat=True)
    # Create an iterator for unlabeled samples for the VAT loss term.
    data_iterator_unlabeled = batch_iterator(
        unlabeled_indices,
        batch_size=self.batch_size,
        shuffle=True,
        allow_smaller_batch=False,
        repeat=True)
    # Create iterators for ll, lu, uu pairs of samples for the agreement term.
    if self.use_graph:
      pair_ll_iterator = self.edge_iterator(
          data, batch_size=self.num_pairs_reg, labeling='ll')
      pair_lu_iterator = self.edge_iterator(
          data, batch_size=self.num_pairs_reg, labeling='lu')
      pair_uu_iterator = self.edge_iterator(
          data, batch_size=self.num_pairs_reg, labeling='uu')
    else:
      pair_ll_iterator = self.pair_iterator(train_indices, train_indices,
                                            self.num_pairs_reg, data)
      pair_lu_iterator = self.pair_iterator(train_indices, unlabeled_indices,
                                            self.num_pairs_reg, data)
      pair_uu_iterator = self.pair_iterator(unlabeled_indices,
                                            unlabeled_indices,
                                            self.num_pairs_reg, data)

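    # Main training loop: optimize until convergence, periodically evaluating
    # and checkpointing the parameters with the best validation accuracy.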
    step = 0
    iter_below_tol = 0
    min_num_iter = self.min_num_iter
    has_converged = step >= self.max_num_iter
    prev_loss_val = np.inf
    best_test_acc = -1
    best_val_acc = -1
    checkpoint_saved = False
    while not has_converged:
      feed_dict = self._construct_feed_dict(
          data_iterator=data_iterator_train,
          split='train',
          pair_ll_iterator=pair_ll_iterator,
          pair_lu_iterator=pair_lu_iterator,
          pair_uu_iterator=pair_uu_iterator,
          data_iterator_unlabeled=data_iterator_unlabeled)
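      # Run one optimization step; on summary steps, also fetch and write
      # TensorBoard summaries, indexed by the total classifier iteration count.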
      if self.enable_summaries and step % self.summary_step == 0:
        loss_val, summary, iter_cls_total, _ = session.run(
            [self.loss_op, self.summary_op, self.iter_cls_total, self.train_op],
            feed_dict=feed_dict)
        summary_writer.add_summary(summary, iter_cls_total)
        summary_writer.flush()
      else:
        loss_val, _ = session.run((self.loss_op, self.train_op),
                                  feed_dict=feed_dict)

      # Log the loss, if necessary.
      if step % self.logging_step == 0:
        logging.info('Classification step %6d | Loss: %10.4f', step, loss_val)

      # Evaluate, if necessary.
      if step % self.eval_step == 0:
        val_acc = self._evaluate(val_indices, 'val', session, summary_writer)
        test_acc = self._evaluate(test_indices, 'test', session, summary_writer)

        if step % self.logging_step == 0 or val_acc > best_val_acc:
          logging.info(
              'Classification step %6d | Loss: %10.4f | val_acc: %10.4f | '
              'test_acc: %10.4f', step, loss_val, val_acc, test_acc)
        if val_acc > best_val_acc:
          best_val_acc = val_acc
          best_test_acc = test_acc
          if self.checkpoint_path:
            self.saver.save(
                session, self.checkpoint_path, write_meta_graph=False)
            checkpoint_saved = True
          # Go for at least num_iter_after_best_val more iterations.
          min_num_iter = max(self.min_num_iter,
                             step + self.num_iter_after_best_val)
          logging.info(
              'Achieved best validation. '
              'Extending to at least %d iterations...', min_num_iter)

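      # Advance the step counter and check for convergence (loss plateau,
      # subject to the min/max iteration bounds).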
      step += 1
      has_converged, iter_below_tol = self.check_convergence(
          prev_loss_val,
          loss_val,
          step,
          self.max_num_iter,
          iter_below_tol,
          min_num_iter=min_num_iter)
      session.run(self.iter_cls_total_update)
      prev_loss_val = loss_val

    # Restore the parameters of the best model found, as measured by
    # validation accuracy.
    if checkpoint_saved:
      logging.info('Restoring best model...')
      self.saver.restore(session, self.checkpoint_path)

    return best_test_acc, best_val_acc
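
For context, below is a minimal, self-contained sketch of the stopping schedule this loop implements: training never stops before min_num_iter steps, and every new best validation accuracy pushes that floor out by num_iter_after_best_val steps. All names in the sketch are illustrative, and the real loop additionally applies a loss-plateau tolerance inside check_convergence, which is not reproduced here.

  def toy_training_loop(val_acc_at_step,
                        eval_step=5,
                        min_num_iter=20,
                        max_num_iter=200,
                        num_iter_after_best_val=30):
    """Mirrors the best-validation extension rule used by train() above."""
    best_val_acc = -1.0
    step = 0
    while step < max_num_iter:
      if step % eval_step == 0:
        val_acc = val_acc_at_step(step)
        if val_acc > best_val_acc:
          best_val_acc = val_acc
          # Same rule as train(): run for at least num_iter_after_best_val
          # more steps after any validation improvement.
          min_num_iter = max(min_num_iter, step + num_iter_after_best_val)
      step += 1
      if step >= min_num_iter:
        break
    return step, best_val_acc

  # Validation accuracy improves until step 10, then plateaus at 0.9:
  final_step, best = toy_training_loop(lambda s: min(0.9, 0.1 * s))
  print(final_step, best)  # 40 0.9 -- last improvement at step 10, plus 30.

Note that the floor is monotone (max(...)), so a late improvement can only extend training, never shorten it.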