def __init__()

in lingvo/core/base_model.py


  def __init__(self, params):
    tp = params.train
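    # If an override checkpoint is given, there must be exactly one
    # init_from_checkpoint rule; re-key that rule onto the override path.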
    if tp and tp.init_from_checkpoint_override is not None:
      assert len(tp.init_from_checkpoint_rules) == 1
      rules = list(tp.init_from_checkpoint_rules.values())[0]
      tp.init_from_checkpoint_rules.clear()
      tp.init_from_checkpoint_rules[tp.init_from_checkpoint_override] = rules
    assert issubclass(params.cls, BaseTask)
    # Ensure global_step exists before calling super.
    py_utils.GetOrCreateGlobalStepVar()
    super().__init__(params)

    p = self.params

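    # Placeholders for child layers; subclasses create the ones they need.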
    self._encoder = None
    self._online_encoder = None
    self._decoder = None

    self._loss = None
    self._train_op = None
    self._post_train_ops = []
    self._eval_metrics = {}
    self._per_example = {}

    # Per-input gradient mask; populated later if gradient masking is used.
    self._per_input_gradient_mask = None

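    # Optionally keep a task-specific step counter instead of the graph-wide
    # global step (useful when tasks in a multi-task model advance at
    # different rates).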
    if p.task_global_step:
      with tf.name_scope(None), tf.variable_scope(
          py_utils.GetGlobalVariableScope()):
        var_name = p.name + '_global_step'
        self.CreateVariable(
            var_name,
            var_params=py_utils.WeightParams([],
                                             py_utils.WeightInit.Constant(0),
                                             tf.int64),
            trainable=False,
            collections=[tf.GraphKeys.GLOBAL_VARIABLES])
        summary_utils.scalar(var_name, self._private_vars[var_name])
        self._global_step_var = self._private_vars[var_name]
    else:
      self._global_step_var = py_utils.GetOrCreateGlobalStepVar()

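    # Everything constructed below sees the chosen step variable as the
    # global step.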
    with py_utils.GlobalStepContext(self._global_step_var):
      if p.input:
        # TODO(zhifengc): Consider a simpler way to ensure the input
        # generator stops after one epoch.
        if self.do_eval and p.eval:
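          # File-based input generators read sequentially; this matters below
          # when aligning eval summaries with epoch boundaries.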
          seq_inp = issubclass(p.input.cls,
                               base_input_generator.BaseInputGeneratorFromFiles)
          if p.input.num_samples > 0:
            if (p.eval.samples_per_summary == 0 or
                p.input.num_samples < p.eval.samples_per_summary):
              p.eval.samples_per_summary = p.input.num_samples
              # If we know the dataset size and we want to evaluate the full
              # set, we need to coordinate the input generator to flush out
              # all samples so the evaler and decoder compute metrics on the
              # whole set for each summary step.
              if seq_inp:
                p.input.flush_every_n = p.input.num_samples
            if p.eval.decoder_samples_per_summary is not None and (
                p.eval.decoder_samples_per_summary > p.input.num_samples):
              p.eval.decoder_samples_per_summary = p.input.num_samples
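          # Explicit per-input overrides take precedence over the values
          # derived from num_samples above.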
          if p.input.eval_samples_per_summary is not None:
            p.eval.samples_per_summary = p.input.eval_samples_per_summary
          if p.input.decoder_samples_per_summary is not None:
            p.eval.decoder_samples_per_summary = (
                p.input.decoder_samples_per_summary)
          if p.input.num_samples == 0 and not p.input.resettable:
            # Dataset size is unknown, so the summary interval cannot be
            # derived from num_samples; a positive samples_per_summary is
            # required for non-resettable inputs.
            # Skip the check if the dataset repeats indefinitely.
            repeated = (
                getattr(p.input, 'repeat_steps', None) or
                getattr(p.input, 'repeat_with_sentinel', False))
            if not repeated:
              assert p.eval.samples_per_summary > 0
          if seq_inp and p.input.num_batcher_threads > 1:
            tf.logging.warning(
                'input.num_batcher_threads > 1 inside eval mode. '
                'The input generator may not iterate over exactly '
                'one epoch per run.')
        input_params = input_policy.Apply(p.input)
        tf.logging.info('input_params: %s', input_params)
        self.CreateChild('input', input_params)

      tp = p.train

      # p.train can be None if this task is the teacher/student task in a
      # DistillationTask.
      if tp:
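        # Convert legacy optimizer/learning-rate params on p.train into a
        # Learner config when no learner is explicitly set.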
        self._SetLearnerFromLegacyParams(tp)
        if tp.learner is not None:
          if isinstance(tp.learner, (list, tuple)):
            self.CreateChildren('learners', tp.learner)
          else:
            self.CreateChildren('learners', [tp.learner])
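      # Reconcile the variational noise (vn) configuration with the training
      # params (e.g. vn is typically disabled in eval).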
      self._UpdateVnConfig()

      if (tp and tp.pruning_hparams_dict and
          pruning_utils.UsePruningInterface(tp.pruning_hparams_dict)):
        pruning_utils.PruningOp.Setup(tp.pruning_hparams_dict, self.global_step)

    # The set of ids of TF graphs in which ApplyExponentialMovingAverage has
    # been called.
    self._graphs_applied_ema = set()
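
For reference, a minimal sketch of the checkpoint-override rewrite at the top
of this constructor, using a plain dict in place of the params object (the
checkpoint paths and rule contents are made up for illustration):

  rules_by_ckpt = {'/tmp/old.ckpt': ([('(.*)', '%s')], [])}
  override = '/tmp/new.ckpt'  # stand-in for tp.init_from_checkpoint_override
  if override is not None:
    assert len(rules_by_ckpt) == 1  # exactly one rule to redirect
    rules = list(rules_by_ckpt.values())[0]
    rules_by_ckpt.clear()
    rules_by_ckpt[override] = rules
  assert rules_by_ckpt == {'/tmp/new.ckpt': ([('(.*)', '%s')], [])}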