def __init__()

in lingvo/executor.py [0:0]


  def __init__(self, train_cfg, ps_params_dict, *args, **kwargs):
    """Construct an ExecutorTpu BaseRunner.

    In order, this constructor: connects to the TPU cluster (eager mode only),
    calls the BaseRunner constructor, sets up single-task vs. multi-task state,
    initializes the TPU system and device assignment for each worker,
    instantiates one ProgramSchedule per entry of `ps_params_dict`, builds every
    program's TPU subgraph, and creates the checkpointer.

    Args:
      train_cfg: SingleTaskModelParams or MultiTaskModelParams
      ps_params_dict: A dict of top-level task name -> ProgramSchedule params,
        if train_cfg is a SingleTaskModelParams, we expect only one entry.
      *args: List args to pass through to BaseRunner.
      **kwargs: keyword args to pass through to BaseRunner.

    Raises:
      ValueError: If the ProgramSchedule params specify more than one distinct
        non-empty `checkpoint_to_load`.
    """
    if py_utils.IsEagerMode():
      assert tf.executing_eagerly()
      tf.logging.info(f'FLAGS.tf_master: {FLAGS.tf_master}')

      # Connect to the TPU runtime.
      resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
          FLAGS.tf_master, job_name=FLAGS.worker_job[len('/job:'):])
      tf.config.experimental_connect_to_cluster(resolver)

    super().__init__(train_cfg, *args, **kwargs)

    data_parallelism = self._cluster.num_splits_per_client
    assert data_parallelism
    num_devices_per_split = self._cluster.num_devices_per_split
    tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                    data_parallelism, num_devices_per_split)

    self.task_scheduler = None
    self._checkpoint_dir = os.path.join(self._logdir, 'train')

    self._variable_renaming_rules = []

    # Resolve the MLPerf params up front: the `_ml_perf_log` decision below
    # (and the 'benchmark' / 'init_start' prints that depend on it) happens
    # before the program schedules are instantiated. When several schedules
    # carry a benchmark name, the last one in iteration order wins.
    self._ml_perf = None
    for program_schedule_params in ps_params_dict.values():
      if program_schedule_params.ml_perf.benchmark_name is not None:
        self._ml_perf = program_schedule_params.ml_perf

    # If this is a multi-task model, grab the params for the TaskScheduler.
    if issubclass(train_cfg.cls, base_model.SingleTaskModel):
      tf.logging.info('single_task_model')
      assert len(ps_params_dict) == 1
      self._model_task_name = list(ps_params_dict.keys())[0]
      self._single_task_mode = True
    elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
      tf.logging.info('multi_task_model')

      if issubclass(train_cfg.cls, multitask_model.RegExSharedVariableModel):
        self._variable_renaming_rules = train_cfg.variable_renaming_rules

      if train_cfg.task_schedule is None:
        # No explicit schedule: fall back to a ConstantScheduler built from
        # the model's task_probs.
        task_schedule_params = task_scheduler.ConstantScheduler.Params()
        task_schedule_params.task_probs = sorted(
            train_cfg.task_probs.IterParams())
      else:
        task_schedule_params = train_cfg.task_schedule
      self.task_scheduler = task_schedule_params.Instantiate()
      self._single_task_mode = False
    else:
      tf.logging.fatal(
          'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
          train_cfg.cls)

    tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

    self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                     'trainer_params.txt')
    if self._ml_perf is not None:
      self._ml_perf_log = True
      mlp_log.mlperf_print(key='benchmark', value=self._ml_perf.benchmark_name)
    else:
      self._ml_perf_log = False

    # From here on use the params held by the runner itself.
    train_cfg = self.params

    @py_utils.RetryOnTransientTfError()
    def _WaitTillInit(job=None):
      """Initializes the TPU system and records the device assignment.

      Args:
        job: Job name forwarded to `tf.tpu.initialize_system` (graph mode only)
          and to `py_utils.SetTpuDeviceAssignment`; may be None.
      """
      try:
        if py_utils.IsEagerMode():
          topology = tf.tpu.experimental.initialize_tpu_system(resolver)
        else:
          # tpu.initialize_system() is called with None as embedding_config, as
          # embedding_config is not available yet. Later in _Loop, it is called
          # with the correct embedding_config. Since it cannot be called twice
          # in the same graph with different embedding_config, we use a
          # dummy_graph here.
          dummy_graph = tf.Graph()
          with dummy_graph.as_default():
            tpu_initialize_system_op = tf.tpu.initialize_system(
                embedding_config=None, job=job)

          with self._GetSession(graph=dummy_graph) as sess:
            topology = sess.run(tpu_initialize_system_op)

        if train_cfg.train.tpu_computation_shape is None:
          computation_shape = py_utils.ComputationShape(num_devices_per_split,
                                                        topology)
        else:
          computation_shape = train_cfg.train.tpu_computation_shape
          assert num_devices_per_split == np.prod(computation_shape)

        # Only forward device_order_mode when it is configured, so that the
        # library default applies otherwise.
        device_assignment_kwargs = dict(
            computation_shape=computation_shape,
            num_replicas=data_parallelism)
        if train_cfg.train.tpu_device_order_mode is not None:
          device_assignment_kwargs['device_order_mode'] = (
              train_cfg.train.tpu_device_order_mode)
        device_assignment = device_assignment_lib.device_assignment(
            topology, **device_assignment_kwargs)
        py_utils.SetTpuDeviceAssignment(device_assignment, job)
        tf.logging.info('device_assignment.core_assignment: %s',
                        str(device_assignment.core_assignment))
        tf.logging.info('device_assignment.topology.device_coordinates: %s',
                        str(device_assignment.topology.device_coordinates))
      except py_utils.transient_tf_errors as e:
        # Log and re-raise so that the retry decorator sees the error.
        tf.logging.info('TPU initialization failed: %s', e)
        raise

    if self._ml_perf_log:
      mlp_log.mlperf_print(key='init_start', value=None)
    if len(self._cluster.all_worker_names) > 1:
      for worker in self._cluster.all_worker_names:
        _WaitTillInit(worker)
    else:
      _WaitTillInit(None)

    shared_model = self._MaybeConstructSharedModel(train_cfg)

    self._program_schedule_dict = {}
    self._programs = []
    self._ckpt_programs = []

    self._checkpoint_to_load = None
    with self._cluster:
      # Create the ExponentialMovingAverage singleton shared by all programs, if
      # applicable.
      ema = py_utils.CreateEMAForModel(train_cfg, self._global_step_var)
      for task_string, program_schedule_params in ps_params_dict.items():
        program_schedule_params.logdir = self._logdir
        program_schedule_params.num_splits_per_client = data_parallelism
        program_schedule_params.task_name = task_string
        # If the model was created above, we'll inject it here as a
        # shared_model.
        ps = program_schedule_params.Instantiate(
            shared_model=shared_model,
            trial=self._trial,
            ema=ema,
            tf_master=self._tf_master)
        self._program_schedule_dict[task_string] = ps
        tf.logging.info('program_schedule_params: %s',
                        program_schedule_params.ToText())
        self._programs += ps.Programs()
        # Checkpoint through the train program when the schedule has one;
        # otherwise through all of the schedule's programs.
        if ps.train_program:
          self._ckpt_programs.append(ps.train_program)
        else:
          self._ckpt_programs += ps.Programs()
        if ('checkpoint_to_load' in program_schedule_params and
            program_schedule_params.checkpoint_to_load):
          if (self._checkpoint_to_load and
              (self._checkpoint_to_load !=
               program_schedule_params.checkpoint_to_load)):
            raise ValueError(f'Multiple values found for checkpoint_to_load: '
                             f'{self._checkpoint_to_load}, '
                             f'{program_schedule_params.checkpoint_to_load}.')
          self._checkpoint_to_load = program_schedule_params.checkpoint_to_load

    tf.logging.info('num_programs: %d', len(self._programs))

    # When running in a vizier trainer, the executor reports infeasible runs
    # in case of errors. The programs report metrics and normal completions.
    for program in self._programs:
      if program._should_report_metrics:
        self._should_report_metrics = True

    with self._cluster, tf.container(
        self._container_id), contextlib.ExitStack() as stack:
      if not py_utils.IsEagerMode():
        stack.enter_context(self._graph.as_default())
        stack.enter_context(tf.device(self._cluster.GetPlacer()))
      if FLAGS.pdb_on_exception:
        stack.enter_context(pdb_wrapper.catch_post_mortem())
      with py_utils.VariableStore(), py_utils.VariableRenameScope(
          self._variable_renaming_rules):
        for program in self._programs:
          program.BuildTpuSubgraph()
          py_utils.ClearTpuSummaryTensors()

      if not py_utils.IsEagerMode():
        self._initialize_tables = tf.tables_initializer()
        self._initialize_local_vars = tf.local_variables_initializer()
        self._initialize_global_vars = tf.global_variables_initializer()

      checkpointer_models = [
          program.GetModel() for program in self._ckpt_programs
      ]

      # The three checkpointer flavours take the same arguments; only the
      # class and the init op differ between eager and graph mode.
      if py_utils.IsEagerMode():
        if FLAGS.use_v2_checkpoints_in_eager:
          checkpointer_cls = checkpointer.EagerCheckpointerV2
        else:
          checkpointer_cls = checkpointer.EagerCheckpointerV1
        checkpointer_init_op = None
      else:
        checkpointer_cls = checkpointer.Checkpointer
        checkpointer_init_op = self._initialize_global_vars
      self._checkpointer = checkpointer_cls(
          self._checkpoint_dir,
          models=checkpointer_models,
          init_op=checkpointer_init_op,
          train_params=train_cfg.train,
          save_only=False)

      for program in self._programs:
        program.SetStatusMessageFn(self._SetStatusMessage)

      # Cache the TPU embedding handles from the TpuEmbeddingCollection.
      tpu_embedding_collection = (
          tpu_embedding_layers.TpuEmbeddingCollection.Get())
      self._load_ops = tpu_embedding_collection.load_ops
      self._retrieve_ops = tpu_embedding_collection.retrieve_ops
      self._tpu_embedding = tpu_embedding_collection.tpu_embedding