def train()

in research/gam/gam/trainer/trainer_cotrain.py [0:0]


  def train(self, data, **kwargs):
    """Runs the co-training loop between the agreement and classification models.

    Each co-training step trains the agreement model (skipped on step 0 when
    `self.first_iter_original` is set), trains the classification model, logs
    validation/test accuracies, and then extends the labeled set by
    self-labeling. The loop stops when a self-labeling round adds no new
    samples or when `self.max_num_iter_cotrain` steps have been run.

    Args:
      data: The dataset to train on. It is wrapped in a `CotrainDataset`, which
        also accounts for co-training specific attributes and functions.
      **kwargs: Unused by this method.
    """
    # Create a wrapper around the dataset, that also accounts for some
    # cotrain specific attributes and functions.
    data = CotrainDataset(
        data,
        keep_label_proportions=self.keep_label_proportions,
        inductive=self.inductive)

    # Guard against an unset `self.data_dir`: `os.path.exists(None)` raises a
    # TypeError instead of returning False.
    if (self.data_dir and os.path.exists(self.data_dir) and
        self.load_from_checkpoint):
      # If this session is restored from a previous run, then we load the
      # self-labeled data from the last checkpoint.
      logging.info('Number of labeled samples before restoring: %d',
                   data.num_train())
      logging.info('Restoring self-labeled data from %s...', self.data_dir)
      data.restore_state_from_file(self.data_dir)
      logging.info('Number of labeled samples after restoring: %d',
                   data.num_train())

    # Build graph.
    logging.info('Building graph...')

    # Create an iteration counter, stored as a variable so that it can be
    # checkpointed together with the models.
    iter_cotrain, iter_cotrain_update = self._create_counter()

    if self.use_perfect_agr:
      # A perfect agreement model used for debugging purposes.
      trainer_agr = TrainerPerfectAgreement(data=data)
    else:
      with tf.variable_scope('AgreementModel'):
        if self.always_agree:
          trainer_agr = TrainerAgreementAlwaysAgree(data=data)
        else:
          trainer_agr = TrainerAgreement(
              model=self.model_agr,
              data=data,
              optimizer=self.optimizer,
              gradient_clip=self.gradient_clip,
              min_num_iter=self.min_num_iter_agr,
              max_num_iter=self.max_num_iter_agr,
              num_iter_after_best_val=self.num_iter_after_best_val_agr,
              max_num_iter_cotrain=self.max_num_iter_cotrain,
              num_warm_up_iter=self.num_warm_up_iter_agr,
              warm_start=self.warm_start_agr,
              batch_size=self.batch_size_agr,
              enable_summaries=self.enable_summaries_per_model,
              summary_step=self.summary_step_agr,
              summary_dir=self.summary_dir,
              logging_step=self.logging_step_agr,
              eval_step=self.eval_step_agr,
              abs_loss_chg_tol=self.abs_loss_chg_tol,
              rel_loss_chg_tol=self.rel_loss_chg_tol,
              loss_chg_iter_below_tol=self.loss_chg_iter_below_tol,
              checkpoints_dir=self.checkpoints_dir,
              weight_decay=self.weight_decay_agr,
              weight_decay_schedule=self.weight_decay_schedule_agr,
              agree_by_default=False,
              percent_val=self.ratio_valid_agr,
              max_num_samples_val=self.max_samples_valid_agr,
              seed=self.seed,
              lr_decay_rate=self.lr_decay_rate_agr,
              lr_decay_steps=self.lr_decay_steps_agr,
              lr_initial=self.learning_rate_agr,
              use_graph=self.use_graph,
              add_negative_edges=self.add_negative_edges_agr)

    if self.use_perfect_cls:
      # A perfect classification model used for debugging purposes.
      trainer_cls = TrainerPerfectClassification(data=data)
    else:
      with tf.variable_scope('ClassificationModel'):
        # GCN models need a dedicated trainer.
        trainer_cls_class = (
            TrainerClassificationGCN
            if isinstance(self.model_cls, GCN) else TrainerClassification)
        trainer_cls = trainer_cls_class(
            model=self.model_cls,
            data=data,
            trainer_agr=trainer_agr,
            optimizer=self.optimizer,
            gradient_clip=self.gradient_clip,
            batch_size=self.batch_size_cls,
            min_num_iter=self.min_num_iter_cls,
            max_num_iter=self.max_num_iter_cls,
            num_iter_after_best_val=self.num_iter_after_best_val_cls,
            max_num_iter_cotrain=self.max_num_iter_cotrain,
            reg_weight_ll=self.reg_weight_ll,
            reg_weight_lu=self.reg_weight_lu,
            reg_weight_uu=self.reg_weight_uu,
            num_pairs_reg=self.num_pairs_reg,
            reg_weight_vat=self.reg_weight_vat,
            use_ent_min=self.use_ent_min,
            enable_summaries=self.enable_summaries_per_model,
            summary_step=self.summary_step_cls,
            summary_dir=self.summary_dir,
            logging_step=self.logging_step_cls,
            eval_step=self.eval_step_cls,
            abs_loss_chg_tol=self.abs_loss_chg_tol,
            rel_loss_chg_tol=self.rel_loss_chg_tol,
            loss_chg_iter_below_tol=self.loss_chg_iter_below_tol,
            warm_start=self.warm_start_cls,
            checkpoints_dir=self.checkpoints_dir,
            weight_decay=self.weight_decay_cls,
            weight_decay_schedule=self.weight_decay_schedule_cls,
            penalize_neg_agr=self.penalize_neg_agr,
            use_l2_classif=self.use_l2_classif,
            first_iter_original=self.first_iter_original,
            seed=self.seed,
            iter_cotrain=iter_cotrain,
            lr_decay_rate=self.lr_decay_rate_cls,
            lr_decay_steps=self.lr_decay_steps_cls,
            lr_initial=self.learning_rate_cls,
            use_graph=self.use_graph)

    # Create a saver which saves only the variables that we would need to
    # restore in case the training process is restarted.
    vars_to_save = ([iter_cotrain] + trainer_agr.vars_to_save +
                    trainer_cls.vars_to_save)
    saver = tf.train.Saver(vars_to_save)

    # Create a TensorFlow session. We allow soft placement in order to place
    # any supported ops on GPU. The allow_growth option lets our process
    # progressively use more gpu memory, per need basis, as opposed to
    # allocating it all from the beginning.
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    session = tf.Session(config=config)

    # Create a Tensorflow summary writer, shared by all models.
    summary_writer = tf.summary.FileWriter(self.summary_dir, session.graph)

    # Initialize the values of all variables and the train dataset iterator.
    session.run(tf.global_variables_initializer())

    # If a checkpoint with the variables already exists, we restore them.
    if self.checkpoints_dir:
      checkpts_path_cotrain = os.path.join(self.checkpoints_dir, 'cotrain.ckpt')
      if not os.path.exists(self.checkpoints_dir):
        os.makedirs(self.checkpoints_dir)
      # `checkpts_path_cotrain` is a checkpoint *prefix*, not a file or a
      # directory: the saver above (default V2 write format) produces
      # `<prefix>.index` and `<prefix>.data-*` files. Checking for the index
      # file ensures we only restore when a checkpoint was actually written.
      # The mere existence of a path named like the prefix is not enough.
      if (self.load_from_checkpoint and
          os.path.exists(checkpts_path_cotrain + '.index')):
        saver.restore(session, checkpts_path_cotrain)
    else:
      checkpts_path_cotrain = None

    # Create a progress bar showing how many samples are labeled.
    pbar = tqdm(
        total=data.num_samples - data.num_train(), desc='self-labeled nodes')

    logging.info('Starting co-training...')
    step = session.run(iter_cotrain)
    stop = step >= self.max_num_iter_cotrain
    best_val_acc = -1
    test_acc_at_best = -1
    iter_at_best = -1
    try:
      while not stop:
        logging.info('----------------- Cotrain step %6d -----------------',
                     step)
        # With `first_iter_original`, step 0 only trains the original
        # classifier: the agreement model is not trained and no self-labeling
        # happens.
        original_only = self.first_iter_original and step == 0

        # Train the agreement model.
        if original_only:
          logging.info('First iteration trains the original classifier. '
                       'No need to train the agreement model.')
          val_acc_agree = None
          acc_pred_by_agr = None
        else:
          val_acc_agree = trainer_agr.train(
              data, session=session, summary_writer=summary_writer)

          if self.eval_acc_pred_by_agr:
            # Evaluate the prediction accuracy by a majority vote model using
            # the agreement model.
            logging.info('Computing agreement majority vote predictions on '
                         'test data...')
            acc_pred_by_agr = trainer_agr.predict_label_by_agreement(
                session, data.get_indices_test(),
                self.num_neighbors_pred_by_agr)
          else:
            acc_pred_by_agr = None

        # Train classification model.
        test_acc, val_acc = trainer_cls.train(
            data, session=session, summary_writer=summary_writer)

        if val_acc > best_val_acc:
          best_val_acc = val_acc
          test_acc_at_best = test_acc
          iter_at_best = step

        if self.enable_summaries:
          summary = tf.Summary()
          summary.value.add(tag='cotrain/test_acc', simple_value=test_acc)
          summary.value.add(tag='cotrain/val_acc', simple_value=val_acc)
          if val_acc_agree is not None:
            summary.value.add(
                tag='cotrain/val_acc_agree', simple_value=val_acc_agree)
          if acc_pred_by_agr is not None:
            summary.value.add(
                tag='cotrain/acc_predict_by_agreement',
                simple_value=acc_pred_by_agr)
          summary_writer.add_summary(summary, step)
          summary_writer.flush()

        logging.info(
            '--------- Cotrain step %6d | Accuracy val: %10.4f | '
            'Accuracy test: %10.4f ---------', step, val_acc, test_acc)
        logging.info(
            'Best validation acc: %.4f, corresponding test acc: %.4f at '
            'iteration %d', best_val_acc, test_acc_at_best, iter_at_best)
        if original_only:
          logging.info('No self-labeling because the first iteration trains '
                       'the original classifier for evaluation purposes.')
          step += 1
          session.run(iter_cotrain_update)
        else:
          # Extend labeled set by self-labeling.
          logging.info('Self-labeling...')
          selected_samples = self._extend_label_set(data, trainer_cls, session)

          # If no new data points are added to the training set, stop.
          num_new_labels = len(selected_samples)
          pbar.update(num_new_labels)
          if num_new_labels > 0:
            data.compute_dataset_statistics(selected_samples, summary_writer,
                                            step)
          else:
            logging.info('No new samples labeled. Stopping...')
            stop = True

          step += 1
          stop |= step >= self.max_num_iter_cotrain

          # Advance the persisted counter *before* saving, so the checkpoint
          # records `step` (the next step to run). The dataset handed to
          # `_save_state` has presumably already been extended by this step's
          # self-labeling. Saving the old counter value would make a restored
          # run repeat the last completed step.
          session.run(iter_cotrain_update)

          # Save model and dataset state in case of process preemption.
          if self.checkpoints_step and step % self.checkpoints_step == 0:
            self._save_state(saver, session, data, checkpts_path_cotrain)

        logging.info('________________________________________________________')

      logging.info(
          'Best validation acc: %.4f, corresponding test acc: %.4f at '
          'iteration %d', best_val_acc, test_acc_at_best, iter_at_best)
    finally:
      # Release resources owned by this method, even if training fails. The
      # writer is flushed on close; the session frees its (GPU) memory.
      pbar.close()
      summary_writer.close()
      session.close()