research/gam/gam/trainer/trainer_classification.py [733:770]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  def train(self, data, session=None, **kwargs):
    """Train the classification model on the provided dataset.

    Args:
      data: A CotrainDataset object.
      session: A TensorFlow session or None.
      **kwargs: Other keyword arguments.

    Returns:
      best_test_acc: A float representing the test accuracy at the iteration
        where the validation accuracy is maximum.
      best_val_acc: A float representing the best validation accuracy.
    """
    summary_writer = kwargs['summary_writer']
    logging.info('Training classifier...')

    if not self.is_initialized:
      self.is_initialized = True
    else:
      if self.weight_decay_update is not None:
        session.run(self.weight_decay_update)
        logging.info('New weight decay value:  %f',
                     session.run(self.weight_decay_var))
      # Reset the optimizer state (e.g., momentum).
      session.run(self.reset_optimizer)

    if not self.warm_start:
      # Re-initialize variables.
      initializers = [v.initializer for v in self.variables.values()]
      initializers.append(self.global_step.initializer)
      session.run(initializers)

    # Construct data iterator.
    logging.info('Training classifier with %d samples...', data.num_train())
    train_indices = data.get_indices_train()
    unlabeled_indices = data.get_indices_unlabeled()
    val_indices = data.get_indices_val()
    test_indices = data.get_indices_test()
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



research/gam/gam/trainer/trainer_classification_gcn.py [745:782]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  def train(self, data, session=None, **kwargs):
    """Train the classification model on the provided dataset.

    Args:
      data: A CotrainDataset object.
      session: A TensorFlow session or None.
      **kwargs: Other keyword arguments.

    Returns:
      best_test_acc: A float representing the test accuracy at the iteration
        where the validation accuracy is maximum.
      best_val_acc: A float representing the best validation accuracy.
    """
    summary_writer = kwargs['summary_writer']
    logging.info('Training classifier...')

    if not self.is_initialized:
      self.is_initialized = True
    else:
      if self.weight_decay_update is not None:
        session.run(self.weight_decay_update)
        logging.info('New weight decay value:  %f',
                     session.run(self.weight_decay_var))
      # Reset the optimizer state (e.g., momentum).
      session.run(self.reset_optimizer)

    if not self.warm_start:
      # Re-initialize variables.
      initializers = [v.initializer for v in self.variables.values()]
      initializers.append(self.global_step.initializer)
      session.run(initializers)

    # Construct data iterator.
    logging.info('Training classifier with %d samples...', data.num_train())
    train_indices = data.get_indices_train()
    unlabeled_indices = data.get_indices_unlabeled()
    val_indices = data.get_indices_val()
    test_indices = data.get_indices_test()
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



