in lingvo/core/program.py [0:0]
def UpdateProgramSchedule(ps_params,
                          dataset_list,
                          train_executions_per_eval,
                          train_steps_per_loop,
                          eval_steps_per_loop,
                          decode_steps_per_loop,
                          decode_summary_emails=None,
                          oneoff_checkpoint_to_load=None):
  """Overrides selected fields of ProgramSchedule params, in place.

  Only EvalProgram / DecodeProgram entries of `ps_params.eval_programs` and a
  few schedule-level fields are touched. A MultiTaskProgramSchedule is not
  supported: it is logged and returned unchanged.

  Args:
    ps_params: SimpleProgramSchedule.Params() to be overridden; mutated in
      place.
    dataset_list: Optional[List[str]]. If not None, it replaces
      `ps_params.dataset_names` and the set of evaluated datasets. Existing
      eval programs whose `dataset_name` is not in the list are dropped. Each
      listed dataset that lacks an EvalProgram ('eval_tpu') or a
      DecodeProgram ('decode_tpu') gets a newly created one. New programs use
      the steps-per-loop of the last EvalProgram / DecodeProgram in the
      original schedule, or 0 if there was none.
    train_executions_per_eval: Optional[int]. If not None, overrides
      `ps_params.train_executions_per_eval`.
    train_steps_per_loop: Optional[int]. If not None, overrides the train
      program's `steps_per_loop`.
    eval_steps_per_loop: Optional[int]. If not None, overrides
      `steps_per_loop` of every EvalProgram; the value 0 instead removes all
      EvalPrograms. Lists are not supported.
    decode_steps_per_loop: Optional[int]. If not None, overrides the steps per
      loop of every DecodeProgram (via `_SetDecodeStepsPerLoop`; -1 means
      decode_until_out_of_range=True); the value 0 instead removes all
      DecodePrograms. Lists are not supported.
    decode_summary_emails: Optional list of emails to send Decode summaries
      to; if non-empty it is assigned to every DecodeProgram.
    oneoff_checkpoint_to_load: Optional[str]. If non-empty, overrides
      `ps_params.checkpoint_to_load`. A warning is logged when the schedule
      also trains (truthy `train_executions_per_eval`).

  Returns:
    The same `ps_params` object, after the overrides.
  """
  assert ps_params
  if issubclass(ps_params.cls, MultiTaskProgramSchedule):
    tf.logging.info(
        'UpdateProgramSchedule does not support MultiTaskProgramSchedule.')
    return ps_params

  if dataset_list is not None:
    ps_params.dataset_names = dataset_list
    # Requested datasets, de-duplicated, in first-seen order.
    wanted = dict.fromkeys(dataset_list)
    # Requested datasets that already have an EvalProgram / DecodeProgram.
    has_eval, has_decode = set(), set()
    kept = []
    eval_default, decode_default = 0, 0
    for prog in ps_params.eval_programs:
      is_eval = issubclass(prog.cls, EvalProgram)
      is_decode = not is_eval and issubclass(prog.cls, DecodeProgram)
      # Defaults track the last program of each kind, including programs whose
      # dataset is about to be dropped.
      if is_eval:
        eval_default = prog.steps_per_loop
      elif is_decode:
        decode_default = _GetDecodeStepsPerLoop(prog)
      name = prog.dataset_name
      if name not in wanted:
        continue
      kept.append(prog)
      if is_eval:
        has_eval.add(name)
      elif is_decode:
        has_decode.add(name)
    # Fill the gaps: one eval and one decode program per requested dataset.
    # NOTE(review): with no program of a kind in the original schedule, the
    # created programs get 0 steps per loop unless overridden below -- confirm
    # this is intended.
    for name in wanted:
      if name not in has_eval:
        kept.append(
            _CreateProgramParams(EvalProgram, 'eval_tpu', name, eval_default))
      if name not in has_decode:
        kept.append(
            _CreateProgramParams(DecodeProgram, 'decode_tpu', name,
                                 decode_default))
    ps_params.eval_programs = kept

  if train_executions_per_eval is not None:
    ps_params.train_executions_per_eval = train_executions_per_eval
  if train_steps_per_loop is not None:
    ps_params.train_program.steps_per_loop = train_steps_per_loop

  def _ProgramsOfKind(program_cls):
    """Returns the current eval_programs entries whose cls is program_cls."""
    return [
        p for p in ps_params.eval_programs if issubclass(p.cls, program_cls)
    ]

  def _OverrideStepsPerLoop(steps, program_cls, set_steps):
    """Applies one steps-per-loop override; 0 drops that program kind."""
    if steps is None:
      return
    if steps == 0:
      ps_params.eval_programs = _ClearSpecifiedProgram(
          ps_params.eval_programs, program_cls)
      return
    for prog in _ProgramsOfKind(program_cls):
      set_steps(prog, steps)

  def _SetEvalStepsPerLoop(prog, steps):
    prog.steps_per_loop = steps

  # Eval is handled before decode, matching the documented argument order.
  _OverrideStepsPerLoop(eval_steps_per_loop, EvalProgram, _SetEvalStepsPerLoop)
  _OverrideStepsPerLoop(decode_steps_per_loop, DecodeProgram,
                        _SetDecodeStepsPerLoop)

  # Runs after the train_executions_per_eval override, so the warning sees the
  # final value.
  if oneoff_checkpoint_to_load:
    if ps_params.train_executions_per_eval:
      tf.logging.warning(
          'Training with decoding does not suggest to set `checkpoint_to_load` '
          'for DecodeProgram!')
    ps_params.checkpoint_to_load = oneoff_checkpoint_to_load

  if decode_summary_emails:
    for prog in _ProgramsOfKind(DecodeProgram):
      prog.emails = decode_summary_emails
  return ps_params