def __init__()

in research/gam/gam/trainer/trainer_classification_gcn.py [0:0]


  def __init__(self,
               model,
               data,
               trainer_agr,
               optimizer,
               lr_initial,
               batch_size,
               min_num_iter,
               max_num_iter,
               num_iter_after_best_val,
               max_num_iter_cotrain,
               reg_weight_ll,
               reg_weight_lu,
               reg_weight_uu,
               num_pairs_reg,
               iter_cotrain,
               reg_weight_vat=0.0,
               use_ent_min=False,
               enable_summaries=False,
               summary_step=1,
               summary_dir=None,
               warm_start=False,
               gradient_clip=None,
               logging_step=1,
               eval_step=1,
               abs_loss_chg_tol=1e-10,
               rel_loss_chg_tol=1e-7,
               loss_chg_iter_below_tol=30,
               checkpoints_dir=None,
               weight_decay=None,
               weight_decay_schedule=None,
               penalize_neg_agr=False,
               first_iter_original=True,
               use_l2_classif=True,
               seed=None,
               lr_decay_steps=None,
               lr_decay_rate=None,
               use_graph=False):
    super(TrainerClassificationGCN, self).__init__(
        model=model,
        abs_loss_chg_tol=abs_loss_chg_tol,
        rel_loss_chg_tol=rel_loss_chg_tol,
        loss_chg_iter_below_tol=loss_chg_iter_below_tol)
    self.data = data
    self.trainer_agr = trainer_agr
    self.batch_size = batch_size
    self.min_num_iter = min_num_iter
    self.max_num_iter = max_num_iter
    self.num_iter_after_best_val = num_iter_after_best_val
    self.max_num_iter_cotrain = max_num_iter_cotrain
    self.enable_summaries = enable_summaries
    self.summary_step = summary_step
    self.summary_dir = summary_dir
    self.warm_start = warm_start
    self.gradient_clip = gradient_clip
    self.logging_step = logging_step
    self.eval_step = eval_step
    self.checkpoint_path = (
        os.path.join(checkpoints_dir, 'classif_best.ckpt')
        if checkpoints_dir is not None else None)
    self.weight_decay_initial = weight_decay
    self.weight_decay_schedule = weight_decay_schedule
    self.num_pairs_reg = num_pairs_reg
    self.reg_weight_ll = reg_weight_ll
    self.reg_weight_lu = reg_weight_lu
    self.reg_weight_uu = reg_weight_uu
    self.reg_weight_vat = reg_weight_vat
    self.use_ent_min = use_ent_min
    self.penalize_neg_agr = penalize_neg_agr
    self.use_l2_classif = use_l2_classif
    self.first_iter_original = first_iter_original
    self.iter_cotrain = iter_cotrain
    self.lr_initial = lr_initial
    self.lr_decay_steps = lr_decay_steps
    self.lr_decay_rate = lr_decay_rate
    self.use_graph = use_graph

    # Build TensorFlow graph.
    logging.info('Building classification TensorFlow graph...')

    # Create placeholders.
    input_indices = tf.placeholder(
        tf.int64, shape=(None,), name='input_indices')
    input_indices_unlabeled = tf.placeholder(
        tf.int32, shape=(None,), name='input_indices_unlabeled')
    input_labels = tf.placeholder(tf.int64, shape=(None,), name='input_labels')

    # Create a placeholder specifying if this is train time.
    is_train = tf.placeholder_with_default(False, shape=[], name='is_train')

    # Create some placeholders specific to GCN.
    self.support_op = tf.sparse_placeholder(tf.float32, name='support')
    self.features_op = tf.sparse_placeholder(tf.float32, name='features')
    self.num_features_nonzero_op = tf.placeholder(
        tf.int32, name='num_features_nonzero')

    # Save the data required to fill in these placeholders. We don't add them
    # directly in the graph as constants in order to avoid saving large
    # checkpoints.
    self.support = data.support
    self.features = data.dataset.features_sparse
    self.num_features_nonzero = data.num_features_nonzero

    # Create variables and predictions.
    with tf.variable_scope('predictions'):
      encoding, variables_enc, reg_params_enc = (
          self.model.get_encoding_and_params(
              inputs=self.features_op,
              is_train=is_train,
              support=self.support_op,
              num_features_nonzero=self.num_features_nonzero_op))
      self.variables = variables_enc
      self.reg_params = reg_params_enc
      predictions, variables_pred, reg_params_pred = (
          self.model.get_predictions_and_params(
              encoding=encoding,
              is_train=is_train,
              support=self.support_op,
              num_features_nonzero=self.num_features_nonzero_op))
      self.variables.update(variables_pred)
      self.reg_params.update(reg_params_pred)
      normalized_predictions = self.model.normalize_predictions(predictions)
      predictions_var_scope = tf.get_variable_scope()

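      # Gather the rows of the full-graph predictions corresponding to the
      # current batch of node indices, and build one-hot targets for them.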
      predictions_batch = tf.gather(predictions, input_indices, axis=0)
      normalized_predictions_batch = tf.gather(
          normalized_predictions, input_indices, axis=0)
      one_hot_labels = tf.one_hot(
          input_labels, data.num_classes, name='targets_one_hot')

    # Create a variable for weight decay that may be updated.
    weight_decay_var, weight_decay_update = self._create_weight_decay_var(
        weight_decay, weight_decay_schedule)

    # Create counter for classification iterations.
    iter_cls_total, iter_cls_total_update = self._create_counter()

    # Create loss.
    with tf.name_scope('loss'):
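      # The supervised loss is either a squared L2 distance between the
      # one-hot targets and the normalized predictions, or the model's own
      # loss computed on the raw logits.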
      if self.use_l2_classif:
        loss_supervised = tf.square(one_hot_labels -
                                    normalized_predictions_batch)
        loss_supervised = tf.reduce_sum(loss_supervised, axis=-1)
        loss_supervised = tf.reduce_mean(loss_supervised)
      else:
        loss_supervised = self.model.get_loss(
            predictions=predictions_batch,
            targets=one_hot_labels,
            weight_decay=None)

      # Agreement regularization loss.
      loss_agr = self._get_agreement_reg_loss(data, is_train)
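      # The agreement loss penalizes disagreement between the predictions on
      # pairs of samples (labeled-labeled, labeled-unlabeled and
      # unlabeled-unlabeled), weighted by reg_weight_ll, reg_weight_lu and
      # reg_weight_uu respectively.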
      # If the first co-train iteration trains the original model (for
      # comparison purposes), then we do not add an agreement loss.
      if self.first_iter_original:
        loss_agr_weight = tf.cast(tf.greater(iter_cotrain, 0), tf.float32)
        loss_agr = loss_agr * loss_agr_weight

      # Weight decay loss.
      loss_reg = 0.0
      if weight_decay_var is not None:
        for var in self.reg_params.values():
          loss_reg += weight_decay_var * tf.nn.l2_loss(var)

      # Adversarial loss, in case we want to add VAT on top of GAM.
      ones = tf.fill(tf.shape(input_indices_unlabeled), 1.0)
      unlabeled_mask = tf.scatter_nd(
          input_indices_unlabeled[:, None],
          updates=ones,
          shape=[data.num_samples],
          name='unlabeled_mask')
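      # unlabeled_mask is 1.0 at the indices of the unlabeled samples and 0.0
      # elsewhere, so the adversarial (VAT) loss below is restricted to
      # unlabeled nodes.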
      placeholders = {
          'support': self.support_op,
          'num_features_nonzero': self.num_features_nonzero_op
      }
      loss_vat = get_loss_vat(
          inputs=self.features_op,
          predictions=predictions,
          mask=unlabeled_mask,
          is_train=is_train,
          model=model,
          placeholders=placeholders,
          predictions_var_scope=predictions_var_scope)
      num_unlabeled = tf.shape(input_indices_unlabeled)[0]
      loss_vat = tf.cond(
          tf.greater(num_unlabeled, 0), lambda: loss_vat, lambda: 0.0)
      if self.use_ent_min:
        # Use entropy minimization with VAT (i.e. VATENT).
        loss_ent = entropy_y_x(predictions, unlabeled_mask)
        loss_vat = loss_vat + tf.cond(
            tf.greater(num_unlabeled, 0), lambda: loss_ent, lambda: 0.0)
      loss_vat = loss_vat * self.reg_weight_vat
      if self.first_iter_original:
        # Do not add the adversarial loss in the first iteration if
        # the first iteration trains the plain baseline model.
        weight_loss_vat = tf.cond(
            tf.greater(iter_cotrain, 0), lambda: 1.0, lambda: 0.0)
        loss_vat = loss_vat * weight_loss_vat

      # Total loss.
      loss_op = loss_supervised + loss_agr + loss_reg + loss_vat

    # Create accuracy.
    accuracy = tf.equal(
        tf.argmax(normalized_predictions_batch, 1), input_labels)
    accuracy = tf.reduce_mean(tf.cast(accuracy, tf.float32))

    # Create Tensorboard summaries.
    if self.enable_summaries:
      summaries = [
          tf.summary.scalar('loss_supervised', loss_supervised),
          tf.summary.scalar('loss_agr', loss_agr),
          tf.summary.scalar('loss_reg', loss_reg),
          tf.summary.scalar('loss_total', loss_op)
      ]
      self.summary_op = tf.summary.merge(summaries)

    # Create learning rate schedule and optimizer.
    self.global_step = tf.train.get_or_create_global_step()
    if self.lr_decay_steps is not None and self.lr_decay_rate is not None:
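      # With staircase=True the learning rate follows
      # lr_initial * lr_decay_rate ** floor(global_step / lr_decay_steps).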
      self.lr = tf.train.exponential_decay(
          self.lr_initial,
          self.global_step,
          self.lr_decay_steps,
          self.lr_decay_rate,
          staircase=True)
      self.optimizer = optimizer(self.lr)
    else:
      self.optimizer = optimizer(lr_initial)

    # Get trainable variables and compute gradients.
    grads_and_vars = self.optimizer.compute_gradients(
        loss_op,
        tf.trainable_variables(scope=tf.get_default_graph().get_name_scope()))
    # Clip gradients.
    if self.gradient_clip:
      variab = [elem[1] for elem in grads_and_vars]
      gradients = [elem[0] for elem in grads_and_vars]
      gradients, _ = tf.clip_by_global_norm(gradients, self.gradient_clip)
      grads_and_vars = tuple(zip(gradients, variab))
    with tf.control_dependencies(
        tf.get_collection(
            tf.GraphKeys.UPDATE_OPS,
            scope=tf.get_default_graph().get_name_scope())):
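      # Any ops registered in UPDATE_OPS for this scope (e.g. batch norm
      # moving-average updates) are forced to run before the gradient step.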
      train_op = self.optimizer.apply_gradients(
          grads_and_vars, global_step=self.global_step)

    # Create a saver for model variables.
    trainable_vars = [v for _, v in grads_and_vars]

    # Put together the subset of variables to save and restore from the best
    # validation accuracy as we train the classification model in one co-train
    # round.
    vars_to_save = list(trainable_vars)
    if isinstance(weight_decay_var, tf.Variable):
      vars_to_save.append(weight_decay_var)
    saver = tf.train.Saver(vars_to_save)

    # Put together all variables that need to be saved in case the process is
    # interrupted and needs to be restarted.
    self.vars_to_save = [iter_cls_total, self.global_step]
    if isinstance(weight_decay_var, tf.Variable):
      self.vars_to_save.append(weight_decay_var)
    if self.warm_start:
      # self.variables maps names to variables, so save the variable objects.
      self.vars_to_save.extend(self.variables.values())

    # More variables to be initialized after the session is created.
    self.is_initialized = False

    self.rng = np.random.RandomState(seed)
    self.input_indices = input_indices
    self.input_indices_unlabeled = input_indices_unlabeled
    self.input_labels = input_labels
    self.predictions = predictions
    self.normalized_predictions = normalized_predictions
    self.normalized_predictions_batch = normalized_predictions_batch
    self.weight_decay_var = weight_decay_var
    self.weight_decay_update = weight_decay_update
    self.iter_cls_total = iter_cls_total
    self.iter_cls_total_update = iter_cls_total_update
    self.accuracy = accuracy
    self.train_op = train_op
    self.loss_op = loss_op
    self.saver = saver
    self.batch_size_actual = tf.shape(self.predictions)[0]
    self.reset_optimizer = tf.variables_initializer(self.optimizer.variables())
    self.is_train = is_train
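
A minimal construction sketch (not from the repository): gcn_model, dataset_wrapper, trainer_agreement and iter_cotrain_var are hypothetical placeholders for objects built elsewhere in the GAM pipeline, and the numeric values are arbitrary. Only keyword arguments from the signature above are used; the optimizer is passed as a callable because the constructor invokes optimizer(lr) itself.

  # Hypothetical usage sketch; all names below are placeholders.
  # iter_cotrain_var is a tf.Variable counting co-training iterations.
  trainer = TrainerClassificationGCN(
      model=gcn_model,
      data=dataset_wrapper,
      trainer_agr=trainer_agreement,
      optimizer=tf.train.AdamOptimizer,
      lr_initial=0.01,
      batch_size=128,
      min_num_iter=100,
      max_num_iter=10000,
      num_iter_after_best_val=2000,
      max_num_iter_cotrain=100,
      reg_weight_ll=1.0,
      reg_weight_lu=1.0,
      reg_weight_uu=1.0,
      num_pairs_reg=128,
      iter_cotrain=iter_cotrain_var,
      checkpoints_dir='/tmp/gam_classif')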