def __init__()

in research/gam/gam/trainer/trainer_agreement.py [0:0]


  def __init__(self,
               model,
               data,
               optimizer,
               lr_initial,
               min_num_iter,
               max_num_iter,
               num_iter_after_best_val,
               max_num_iter_cotrain,
               num_warm_up_iter,
               batch_size,
               gradient_clip=None,
               enable_summaries=False,
               summary_step=1,
               summary_dir=None,
               logging_step=1,
               eval_step=1,
               abs_loss_chg_tol=1e-10,
               rel_loss_chg_tol=1e-7,
               loss_chg_iter_below_tol=20,
               warm_start=False,
               checkpoints_dir=None,
               weight_decay=None,
               weight_decay_schedule=None,
               num_pairs_eval_random=1000,
               agree_by_default=False,
               percent_val=0.1,
               max_num_samples_val=10000,
               seed=None,
               lr_decay_steps=None,
               lr_decay_rate=None,
               use_graph=False,
               add_negative_edges=False):
    """Stores the configuration and builds the agreement-model TF graph.

    The agreement model takes a pair of samples (called the "src" and "tgt"
    sample, using graph terminology) and produces an agreement prediction,
    trained against float agreement labels. This constructor creates the
    input placeholders, the prediction / loss / accuracy ops, the optimizer
    and train op, optional TensorBoard summaries, and the savers. Nothing is
    run in a session here; `self.is_initialized` is set to False to signal
    that some initialization still has to happen once a session exists.

    Args:
      model: The agreement model. Forwarded to the parent trainer's
        constructor; `model.get_loss(...)` is used to build the loss op.
      data: Dataset object. Only `data.features_shape` is read here (to shape
        the feature placeholders); the object is stored as `self.data`.
      optimizer: A callable that takes a learning rate (a float, or a tensor
        when decay is enabled) and returns an optimizer exposing
        `compute_gradients`, `apply_gradients` and `variables` -- presumably
        a `tf.train.Optimizer` subclass; confirm with callers.
      lr_initial: Learning rate. Used directly, or as the starting value of
        an exponential decay when `lr_decay_steps` and `lr_decay_rate` are
        both set.
      min_num_iter: Stored as `self.min_num_iter`; not used in this method
        (presumably a minimum number of training iterations -- confirm).
      max_num_iter: Stored as `self.max_num_iter`; not used in this method
        (presumably a maximum number of training iterations).
      num_iter_after_best_val: Stored; not used in this method (presumably an
        early-stopping patience after the best validation score).
      max_num_iter_cotrain: Stored; not used in this method.
      num_warm_up_iter: Stored; not used in this method.
      batch_size: Stored as `self.batch_size`. The graph itself accepts any
        batch size (placeholders have a `None` leading dimension).
      gradient_clip: If truthy, gradients are clipped by global norm to this
        value before being applied.
      enable_summaries: If True, a scalar summary of the loss is created and
        merged into `self.summary_op` (which is only defined in that case).
      summary_step: Stored; not used in this method.
      summary_dir: Stored; not used in this method.
      logging_step: Stored; not used in this method.
      eval_step: Stored; not used in this method.
      abs_loss_chg_tol: Forwarded to the parent trainer's constructor.
      rel_loss_chg_tol: Forwarded to the parent trainer's constructor.
      loss_chg_iter_below_tol: Forwarded to the parent trainer's constructor.
      warm_start: If True, the trainable variables are also added to
        `self.vars_to_save` (the variables kept for process restarts).
      checkpoints_dir: Optional directory. If given, `self.checkpoint_path`
        is `<checkpoints_dir>/agree_best.ckpt`; otherwise it is None.
      weight_decay: Passed to `self._create_weight_decay_var`; the result is
        used as the `weight_decay` argument of `model.get_loss`.
      weight_decay_schedule: Passed to `self._create_weight_decay_var`.
      num_pairs_eval_random: Stored; not used in this method.
      agree_by_default: Stored; not used in this method.
      percent_val: Stored as `self.ratio_val`. The default (0.1) suggests a
        fraction rather than a percentage despite the name -- confirm.
      max_num_samples_val: Stored; not used in this method.
      seed: Seed for `self.rng`, a `np.random.RandomState`.
      lr_decay_steps: `decay_steps` for `tf.train.exponential_decay`
        (staircase mode). Decay is enabled only if this and `lr_decay_rate`
        are both not None.
      lr_decay_rate: `decay_rate` for `tf.train.exponential_decay`.
      use_graph: Stored; not used in this method.
      add_negative_edges: Stored; not used in this method.
    """
    super(TrainerAgreement, self).__init__(
        model=model,
        abs_loss_chg_tol=abs_loss_chg_tol,
        rel_loss_chg_tol=rel_loss_chg_tol,
        loss_chg_iter_below_tol=loss_chg_iter_below_tol)
    self.data = data
    # Holds the optimizer *factory* for now; it is replaced further below by
    # the optimizer instance built from it.
    self.optimizer = optimizer
    self.min_num_iter = min_num_iter
    self.max_num_iter = max_num_iter
    self.num_iter_after_best_val = num_iter_after_best_val
    self.max_num_iter_cotrain = max_num_iter_cotrain
    self.num_warm_up_iter = num_warm_up_iter
    self.batch_size = batch_size
    self.gradient_clip = gradient_clip
    self.enable_summaries = enable_summaries
    self.summary_step = summary_step
    self.summary_dir = summary_dir
    self.checkpoints_dir = checkpoints_dir
    self.logging_step = logging_step
    self.eval_step = eval_step
    self.num_iter_trained = 0
    self.warm_start = warm_start
    # Checkpoint file 'agree_best.ckpt' inside checkpoints_dir, or None when
    # no checkpoints directory was provided.
    self.checkpoint_path = (
        os.path.join(checkpoints_dir, 'agree_best.ckpt')
        if checkpoints_dir is not None else None)
    self.weight_decay = weight_decay
    self.weight_decay_schedule = weight_decay_schedule
    self.num_pairs_eval_random = num_pairs_eval_random
    self.agree_by_default = agree_by_default
    self.ratio_val = percent_val
    self.max_num_samples_val = max_num_samples_val
    # Not set to anything else in this method; presumably filled in elsewhere.
    self.original_var_scope = None
    self.lr_initial = lr_initial
    self.lr_decay_steps = lr_decay_steps
    self.lr_decay_rate = lr_decay_rate
    self.use_graph = use_graph
    self.add_negative_edges = add_negative_edges

    # Build TensorFlow graph.
    # NOTE: the order in which ops/variables are created below determines
    # their auto-generated names, which the savers rely on; avoid reordering.
    logging.info('Building TensorFlow agreement graph...')
    # The agreement model computes the label agreement between two samples.
    # We will refer to these samples as the src and tgt sample, using
    # graph terminology.

    # Create placeholders, and assign to these variables by default.
    # Leading dimension is None so any batch size can be fed.
    features_shape = [None] + list(data.features_shape)
    src_features = tf.placeholder(
        tf.float32, shape=features_shape, name='src_features')
    tgt_features = tf.placeholder(
        tf.float32, shape=features_shape, name='tgt_features')
    # Create a placeholder for the agreement labels (one float per pair).
    labels = tf.placeholder(tf.float32, shape=(None,), name='labels')
    # Create a placeholder specifying if this is train time. Defaults to
    # False, so evaluation/inference feeds do not need to provide it.
    is_train = tf.placeholder_with_default(False, shape=[], name='is_train')

    # Create variables and predictions.
    predictions, normalized_predictions, variables, reg_params = (
        self.create_agreement_prediction(src_features, tgt_features, is_train))

    # Create a variable for weight decay that may be updated later.
    weight_decay_var, weight_decay_update = self._create_weight_decay_var(
        weight_decay, weight_decay_schedule)

    # Create counter for the total number of agreement train iterations.
    iter_agr_total, iter_agr_total_update = self._create_counter()

    # Create loss.
    loss_op = self.model.get_loss(
        predictions=predictions,
        targets=labels,
        reg_params=reg_params,
        weight_decay=weight_decay_var)

    # Create accuracy.
    accuracy = accuracy_binary(normalized_predictions, labels)

    # Create optimizer.
    # NOTE(review): get_or_create_global_step returns the graph-wide global
    # step. If other trainers in the same graph also increment it, the
    # learning-rate decay below advances with their steps too -- confirm
    # this is intended.
    self.global_step = tf.train.get_or_create_global_step()
    if self.lr_decay_steps is not None and self.lr_decay_rate is not None:
      # Staircase exponential decay driven by the global step. Note that
      # `self.lr` is only defined on this branch.
      self.lr = tf.train.exponential_decay(
          self.lr_initial,
          self.global_step,
          self.lr_decay_steps,
          self.lr_decay_rate,
          staircase=True)
      self.optimizer = optimizer(self.lr)
    else:
      # Constant learning rate.
      self.optimizer = optimizer(lr_initial)

    # Create train op.
    # Only trainable variables whose names match the current name scope are
    # optimized, so variables of other models in the graph are left alone.
    grads_and_vars = self.optimizer.compute_gradients(
        loss_op,
        tf.trainable_variables(scope=tf.get_default_graph().get_name_scope()))
    # Clip gradients by their global norm, keeping each paired with its var.
    if self.gradient_clip:
      variab = [elem[1] for elem in grads_and_vars]
      gradients = [elem[0] for elem in grads_and_vars]
      gradients, _ = tf.clip_by_global_norm(gradients, self.gradient_clip)
      grads_and_vars = tuple(zip(gradients, variab))
    # Run the UPDATE_OPS collected under this name scope (e.g. ops registered
    # by layers such as batch normalization) before each gradient step.
    with tf.control_dependencies(
        tf.get_collection(
            tf.GraphKeys.UPDATE_OPS,
            scope=tf.get_default_graph().get_name_scope())):
      train_op = self.optimizer.apply_gradients(
          grads_and_vars, global_step=self.global_step)

    # Create Tensorboard summaries. `self.summary_op` exists only when
    # summaries are enabled.
    if self.enable_summaries:
      summaries = [tf.summary.scalar('loss_agreement_inner', loss_op)]
      self.summary_op = tf.summary.merge(summaries)

    # Create a saver for the model trainable variables (exactly the variables
    # the optimizer updates above).
    self.trainable_vars = [v for _, v in grads_and_vars]

    # Put together the subset of variables to save and restore from the best
    # validation accuracy as we train the agreement model in one cotrain round.
    # `+ []` copies the list so the append below does not modify
    # self.trainable_vars.
    vars_to_save = self.trainable_vars + []
    if isinstance(weight_decay_var, tf.Variable):
      vars_to_save.append(weight_decay_var)
    self.saver = tf.train.Saver(vars_to_save)

    # Put together all variables that need to be saved in case the process is
    # interrupted and needs to be restarted. The model weights are included
    # only when warm starting.
    self.vars_to_save = [iter_agr_total]
    if isinstance(weight_decay_var, tf.Variable):
      self.vars_to_save.append(weight_decay_var)
    if self.warm_start:
      self.vars_to_save += self.trainable_vars

    # More variables to be initialized after the session is created.
    self.is_initialized = False

    self.rng = np.random.RandomState(seed)
    self.src_features = src_features
    self.tgt_features = tgt_features
    self.labels = labels
    self.predictions = predictions
    self.normalized_predictions = normalized_predictions
    self.variables = variables
    self.reg_params = reg_params
    self.weight_decay_var = weight_decay_var
    self.weight_decay_update = weight_decay_update
    self.iter_agr_total = iter_agr_total
    self.iter_agr_total_update = iter_agr_total_update
    self.accuracy = accuracy
    self.train_op = train_op
    self.loss_op = loss_op
    # Dynamic batch size of the fed batch (may differ from self.batch_size).
    self.batch_size_actual = tf.shape(self.predictions)[0]
    # Op that re-initializes the optimizer's own variables (e.g. slot
    # variables). The global step is not an optimizer variable, so a
    # learning-rate decay schedule is presumably not reset by it -- confirm.
    self.reset_optimizer = tf.variables_initializer(self.optimizer.variables())
    self.is_train = is_train