in lingvo/core/base_model.py [0:0]
  def __init__(self, params):
    tp = params.train
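    # If an override checkpoint path is given, re-key the task's single
    # checkpoint-init rule at that path; hence the single-rule assert below.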
    if tp and tp.init_from_checkpoint_override is not None:
      assert len(tp.init_from_checkpoint_rules) == 1
      rules = list(tp.init_from_checkpoint_rules.values())[0]
      tp.init_from_checkpoint_rules.clear()
      tp.init_from_checkpoint_rules[tp.init_from_checkpoint_override] = rules
    assert issubclass(params.cls, BaseTask)
    # Ensure global_step exists before calling super.
    py_utils.GetOrCreateGlobalStepVar()
    super().__init__(params)
    p = self.params

    self._encoder = None
    self._online_encoder = None
    self._decoder = None

    self._loss = None
    self._train_op = None
    self._post_train_ops = []
    self._eval_metrics = {}
    self._per_example = {}

    # Placeholder for the per-input gradient mask; remains None unless a
    # mask is created later.
    self._per_input_gradient_mask = None
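
    # A task may keep its own global step variable instead of sharing the
    # model-wide one.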
    if p.task_global_step:
      with tf.name_scope(None), tf.variable_scope(
          py_utils.GetGlobalVariableScope()):
        var_name = p.name + '_global_step'
        self.CreateVariable(
            var_name,
            var_params=py_utils.WeightParams([],
                                             py_utils.WeightInit.Constant(0),
                                             tf.int64),
            trainable=False,
            collections=[tf.GraphKeys.GLOBAL_VARIABLES])
        summary_utils.scalar(var_name, self._private_vars[var_name])
        self._global_step_var = self._private_vars[var_name]
    else:
      self._global_step_var = py_utils.GetOrCreateGlobalStepVar()
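
    # Build all children with this task's global step bound as the ambient
    # global step.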
    with py_utils.GlobalStepContext(self._global_step_var):
      if p.input:
        # TODO(zhifengc): Consider a simpler way to ensure the input
        # generator stops after one epoch.
        if self.do_eval and p.eval:
          seq_inp = issubclass(p.input.cls,
                               base_input_generator.BaseInputGeneratorFromFiles)
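          # When the dataset size is known, cap the per-summary sample
          # counts at the dataset size.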
          if p.input.num_samples > 0:
            if (p.eval.samples_per_summary == 0) or (
                p.input.num_samples < p.eval.samples_per_summary):
              p.eval.samples_per_summary = p.input.num_samples
            # If we know the dataset size and we want to evaluate the full
            # set, we need to coordinate the input generator to flush out
            # all samples so the evaler and decoder compute metrics on the
            # whole set for each summary step.
            if seq_inp:
              p.input.flush_every_n = p.input.num_samples
            if p.eval.decoder_samples_per_summary is not None and (
                p.eval.decoder_samples_per_summary > p.input.num_samples):
              p.eval.decoder_samples_per_summary = p.input.num_samples
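          # Explicit per-input overrides take precedence over the eval
          # params set above.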
          if p.input.eval_samples_per_summary is not None:
            p.eval.samples_per_summary = p.input.eval_samples_per_summary
          if p.input.decoder_samples_per_summary is not None:
            p.eval.decoder_samples_per_summary = (
                p.input.decoder_samples_per_summary)
          if p.input.num_samples == 0 and not p.input.resettable:
            # Dataset size is unknown and the input cannot be reset to
            # iterate over exactly one epoch, so the eval summary must be
            # computed over a fixed, positive number of samples. Skip the
            # check if the dataset repeats.
            repeated = (
                getattr(p.input, 'repeat_steps', None) or
                getattr(p.input, 'repeat_with_sentinel', False))
            if not repeated:
              assert p.eval.samples_per_summary > 0
          if seq_inp and p.input.num_batcher_threads > 1:
            tf.logging.warning(
                'input.num_batcher_threads > 1 inside eval mode. '
                'The input generator may not iterate over exactly '
                'one epoch per run')
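        # Give the configured input policy a chance to rewrite the input
        # params before the input generator child is created.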
        input_params = input_policy.Apply(p.input)
        tf.logging.info('input_params: %s', input_params)
        self.CreateChild('input', input_params)

      tp = p.train
      # p.train can be None if this task is the teacher/student task in a
      # DistillationTask.
      if tp:
        self._SetLearnerFromLegacyParams(tp)
        if tp.learner is not None:
          if isinstance(tp.learner, (list, tuple)):
            self.CreateChildren('learners', tp.learner)
          else:
            self.CreateChildren('learners', [tp.learner])

      self._UpdateVnConfig()
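
    # Optionally set up model pruning, driven by this task's global step.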
    if (tp and tp.pruning_hparams_dict and
        pruning_utils.UsePruningInterface(tp.pruning_hparams_dict)):
      pruning_utils.PruningOp.Setup(tp.pruning_hparams_dict, self.global_step)

    # The set of ids of TF graphs in which ApplyExponentialMovingAverage has
    # been called.
    self._graphs_applied_ema = set()
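
# A minimal usage sketch for the checkpoint override handled in __init__
# above. Illustrative only: 'MyTask' is a stand-in task class, and the rule
# value shown assumes the (loading_rules, ignore_rules) tuple format used by
# init_from_checkpoint_rules.
#
#   p = MyTask.Params()
#   p.train.init_from_checkpoint_rules = {
#       '/path/to/old.ckpt': ([('(.*)', '%s')], []),
#   }
#   p.train.init_from_checkpoint_override = '/path/to/new.ckpt'
#   task = p.Instantiate()  # __init__ re-keys the single rule to new.ckpt.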