in lingvo/executor.py
def __init__(self, train_cfg, ps_params_dict, *args, **kwargs):
"""Construct an ExecutorTpu BaseRunner.
Args:
train_cfg: SingleTaskModelParams or MultiTaskModelParams
ps_params_dict: A dict of top-level task name -> ProgramSchedule params,
if train_cfg is a SingleTaskModelParams, we expect only one entry.
*args: List args to pass through to BaseRunner.
**kwargs: keyword args to pass through to BaseRunner.
"""
if py_utils.IsEagerMode():
assert tf.executing_eagerly()
tf.logging.info(f'FLAGS.tf_master: {FLAGS.tf_master}')
# Connect to the TPU runtime.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
FLAGS.tf_master, job_name=FLAGS.worker_job[len('/job:'):])
tf.config.experimental_connect_to_cluster(resolver)
super().__init__(train_cfg, *args, **kwargs)
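# num_splits_per_client is the number of data-parallel replicas this client
# drives; num_devices_per_split is the number of TPU cores that make up one
# (model-parallel) split.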
data_parallelism = self._cluster.num_splits_per_client
assert data_parallelism
num_devices_per_split = self._cluster.num_devices_per_split
tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
data_parallelism, num_devices_per_split)
self.task_scheduler = None
self._checkpoint_dir = os.path.join(self._logdir, 'train')
self._variable_renaming_rules = []
self._ml_perf = None
# If this is a multi-task model, grab the params for the TaskScheduler.
if issubclass(train_cfg.cls, base_model.SingleTaskModel):
tf.logging.info('single_task_model')
assert len(ps_params_dict) == 1
self._model_task_name = list(ps_params_dict.keys())[0]
self._single_task_mode = True
elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
tf.logging.info('multi_task_model')
if issubclass(train_cfg.cls, multitask_model.RegExSharedVariableModel):
self._variable_renaming_rules = train_cfg.variable_renaming_rules
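# Default to a constant schedule over the configured task_probs unless an
# explicit task_schedule is provided.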
if train_cfg.task_schedule is None:
task_schedule_params = task_scheduler.ConstantScheduler.Params()
task_schedule_params.task_probs = sorted(
list(train_cfg.task_probs.IterParams()))
else:
task_schedule_params = train_cfg.task_schedule
self.task_scheduler = task_schedule_params.Instantiate()
self._single_task_mode = False
else:
tf.logging.fatal(
'Model %s is not a subclass of SingleTaskModel or MultiTaskModel',
train_cfg.cls)
tf.logging.info('train_cfg.cls: %s', train_cfg.cls)
self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
'trainer_params.txt')
if self._ml_perf is not None:
self._ml_perf_log = True
mlp_log.mlperf_print(key='benchmark', value=self._ml_perf.benchmark_name)
else:
self._ml_perf_log = False
train_cfg = self.params
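# TPU system initialization can fail transiently (e.g. while TPU workers are
# still starting up), so it is retried until it succeeds.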
@py_utils.RetryOnTransientTfError()
def _WaitTillInit(job=None):
"""Wait until the model is ready."""
try:
if py_utils.IsEagerMode():
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
else:
# tf.tpu.initialize_system() is called here with embedding_config=None
# because the embedding config is not yet available; _Loop later calls it
# with the correct embedding_config. Since it cannot be called twice in
# the same graph with different embedding configs, we use a throwaway
# dummy_graph here.
dummy_graph = tf.Graph()
with dummy_graph.as_default():
tpu_initialize_system_op = tf.tpu.initialize_system(
embedding_config=None, job=job)
with self._GetSession(graph=dummy_graph) as sess:
topology = sess.run(tpu_initialize_system_op)
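# Each model split occupies a sub-grid of the TPU topology; the total number
# of cores in computation_shape must equal num_devices_per_split (asserted
# below).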
if train_cfg.train.tpu_computation_shape is None:
computation_shape = py_utils.ComputationShape(num_devices_per_split,
topology)
else:
computation_shape = train_cfg.train.tpu_computation_shape
assert num_devices_per_split == np.prod(computation_shape)
if train_cfg.train.tpu_device_order_mode is None:
device_assignment = device_assignment_lib.device_assignment(
topology,
computation_shape=computation_shape,
num_replicas=data_parallelism)
else:
device_assignment = device_assignment_lib.device_assignment(
topology,
computation_shape=computation_shape,
num_replicas=data_parallelism,
device_order_mode=train_cfg.train.tpu_device_order_mode)
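# Publish the device assignment via py_utils so later graph and program
# construction can place computations on the assigned cores.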
py_utils.SetTpuDeviceAssignment(device_assignment, job)
tf.logging.info('device_assignment.core_assignment: %s',
str(device_assignment.core_assignment))
tf.logging.info('device_assignment.topology.device_coordinates: %s',
str(device_assignment.topology.device_coordinates))
except py_utils.transient_tf_errors as e:
tf.logging.info('TPU initialization failed: %s', e)
raise
if self._ml_perf_log:
mlp_log.mlperf_print(key='init_start', value=None)
if len(self._cluster.all_worker_names) > 1:
for worker in self._cluster.all_worker_names:
_WaitTillInit(worker)
else:
_WaitTillInit(None)
shared_model = self._MaybeConstructSharedModel(train_cfg)
self._program_schedule_dict = {}
self._programs = []
self._ckpt_programs = []
self._checkpoint_to_load = None
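# Instantiate one ProgramSchedule per task; each schedule owns the programs
# that run for that task (e.g. its train program).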
with self._cluster:
# Create the ExponentialMovingAverage singleton shared by all programs, if
# applicable.
ema = py_utils.CreateEMAForModel(train_cfg, self._global_step_var)
for task_string, program_schedule_params in ps_params_dict.items():
program_schedule_params.logdir = self._logdir
program_schedule_params.num_splits_per_client = data_parallelism
program_schedule_params.task_name = task_string
# If the model was created above, we'll inject it here as a
# shared_model.
ps = program_schedule_params.Instantiate(
shared_model=shared_model,
trial=self._trial,
ema=ema,
tf_master=self._tf_master)
self._program_schedule_dict[task_string] = ps
tf.logging.info('program_schedule_params: %s',
program_schedule_params.ToText())
self._programs += ps.Programs()
if ps.train_program:
self._ckpt_programs.append(ps.train_program)
else:
self._ckpt_programs += ps.Programs()
if program_schedule_params.ml_perf.benchmark_name is not None:
self._ml_perf = program_schedule_params.ml_perf
if ('checkpoint_to_load' in program_schedule_params and
program_schedule_params.checkpoint_to_load):
if (self._checkpoint_to_load and
(self._checkpoint_to_load !=
program_schedule_params.checkpoint_to_load)):
raise ValueError(f'Multiple values found for checkpoint_to_load: '
f'{self._checkpoint_to_load}, '
f'{program_schedule_params.checkpoint_to_load}.')
self._checkpoint_to_load = program_schedule_params.checkpoint_to_load
tf.logging.info('num_programs: %d', len(self._programs))
# When running in a Vizier trainer, the executor reports infeasible runs
# when errors occur; the programs report metrics and normal completions.
for program in self._programs:
if program._should_report_metrics:
self._should_report_metrics = True
with self._cluster, tf.container(
self._container_id), contextlib.ExitStack() as stack:
if not py_utils.IsEagerMode():
stack.enter_context(self._graph.as_default())
stack.enter_context(tf.device(self._cluster.GetPlacer()))
if FLAGS.pdb_on_exception:
stack.enter_context(pdb_wrapper.catch_post_mortem())
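# Build all program subgraphs under a shared variable store so that programs
# referencing the same variables reuse them; the renaming rules implement
# cross-task variable sharing (see RegExSharedVariableModel above).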
with py_utils.VariableStore(), py_utils.VariableRenameScope(
self._variable_renaming_rules):
for program in self._programs:
program.BuildTpuSubgraph()
py_utils.ClearTpuSummaryTensors()
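# In graph mode, capture the initializer ops now so they can be run once a
# session is created.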
if not py_utils.IsEagerMode():
self._initialize_tables = tf.tables_initializer()
self._initialize_local_vars = tf.local_variables_initializer()
self._initialize_global_vars = tf.global_variables_initializer()
checkpointer_models = [
program.GetModel() for program in self._ckpt_programs
]
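# Only checkpoint-owning programs contribute models to the checkpointer:
# the train program when one exists, otherwise all of the schedule's
# programs (see the loop above).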
if py_utils.IsEagerMode():
if FLAGS.use_v2_checkpoints_in_eager:
self._checkpointer = checkpointer.EagerCheckpointerV2(
self._checkpoint_dir,
models=checkpointer_models,
init_op=None,
train_params=train_cfg.train,
save_only=False)
else:
self._checkpointer = checkpointer.EagerCheckpointerV1(
self._checkpoint_dir,
models=checkpointer_models,
init_op=None,
train_params=train_cfg.train,
save_only=False)
else:
self._checkpointer = checkpointer.Checkpointer(
self._checkpoint_dir,
models=checkpointer_models,
init_op=self._initialize_global_vars,
train_params=train_cfg.train,
save_only=False)
for program in self._programs:
program.SetStatusMessageFn(self._SetStatusMessage)
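# TPU embedding tables live in device memory: load_ops push the host-side
# tables onto the TPU and retrieve_ops read them back (e.g. for
# checkpointing).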
tpu_embedding_collection = (
tpu_embedding_layers.TpuEmbeddingCollection.Get())
self._load_ops = tpu_embedding_collection.load_ops
self._retrieve_ops = tpu_embedding_collection.retrieve_ops
self._tpu_embedding = tpu_embedding_collection.tpu_embedding