# community-content/vertex_model_garden/model_oss/tfvision/train_hpt_oss.py
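# Trains a TF Vision model for one Vertex AI hyperparameter-tuning (HPT) trial
# and reports the trial's best eval metric back to the tuning service.
#
# A minimal sketch of the imports this excerpt relies on. The flag definitions
# (_LOG_LEVEL, _OBJECTIVE, _MAX_EVAL_WAIT_TIME, FLAGS) and the helpers
# parse_params, get_best_eval_metric, and wait_for_evaluation_file live
# elsewhere in the full module; the import paths for the local constants and
# hypertune_utils modules are assumptions:
#
#   import json
#   import os
#
#   from absl import logging
#   import hypertune
#   import tensorflow as tf
#
#   from official.common import distribute_utils
#   from official.core import task_factory
#   from official.core import train_lib
#   from official.core import train_utils
#   from official.modeling import performance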
def main(_):
  log_level = _LOG_LEVEL.value
  if log_level and log_level in ['FATAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG']:
    logging.set_verbosity(log_level)
  params = parse_params()
  logging.info('The actual training parameters are:\n%s', params.as_dict())
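  # Each HPT trial runs in its own container, so write to a per-trial
  # subdirectory to keep checkpoints from different trials separate.
  # hypertune_utils.get_trial_id_from_environment() presumably reads the trial
  # id that Vertex AI injects into the container environment (e.g. the
  # CLOUD_ML_TRIAL_ID variable).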
  model_dir = os.path.join(
      FLAGS.model_dir,
      'trial_' + hypertune_utils.get_trial_id_from_environment(),
  )
  logging.info('model_dir in this trial is: %s', model_dir)
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files; otherwise a continuous eval
    # job could race against the train job when writing the same file.
    train_utils.serialize_config(params, model_dir)
  # Sets the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can significantly speed up the model by using float16 on GPUs and bfloat16
  # on TPUs. loss_scale takes effect only when the dtype is float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
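  # Illustrative runtime section of a trial config; the field names mirror
  # params.runtime as used above and below, the values are examples only:
  #
  #   runtime:
  #     distribution_strategy: 'mirrored'
  #     mixed_precision_dtype: 'mixed_float16'
  #     num_gpus: 2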
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu,
  )
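  # Build the task (and hence the model) inside the strategy scope so that its
  # variables are created on, and synchronized across, the training devices.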
  with distribution_strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)
  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir,
  )
  train_utils.save_gin_config(FLAGS.mode, model_dir)
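  # Training is done; locate the eval metrics that the best-checkpoint
  # exporter wrote for this trial and forward the objective to the HPT
  # service.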
  eval_metric_name = get_best_eval_metric(_OBJECTIVE.value, params)
  eval_filepath = os.path.join(
      model_dir, constants.BEST_CKPT_DIRNAME, constants.BEST_CKPT_EVAL_FILENAME
  )
  logging.info('Load eval metrics from: %s.', eval_filepath)
  wait_for_evaluation_file(eval_filepath, _MAX_EVAL_WAIT_TIME.value)
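  # The evaluation file stores the best checkpoint's metrics as a JSON dict,
  # e.g. (illustrative; the actual keys depend on the task and on
  # constants.BEST_CKPT_STEP_NAME):
  #   {"<eval_metric_name>": 0.42, "<BEST_CKPT_STEP_NAME>": 10000}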
  with tf.io.gfile.GFile(eval_filepath, 'rb') as f:
    eval_metric_results = json.load(f)
  logging.info('eval metrics are: %s.', eval_metric_results)
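  # Report the objective to Vertex AI hyperparameter tuning via the
  # cloudml-hypertune client. The hyperparameter_metric_tag must match the
  # metric id declared in the tuning job's study configuration; global_step
  # lets the service distinguish values reported at different training steps.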
  if (
      eval_metric_name in eval_metric_results
      and constants.BEST_CKPT_STEP_NAME in eval_metric_results
  ):
    hp_metric = eval_metric_results[eval_metric_name]
    hp_step = int(eval_metric_results[constants.BEST_CKPT_STEP_NAME])
    hpt = hypertune.HyperTune()
    hpt.report_hyperparameter_tuning_metric(
        hyperparameter_metric_tag=constants.HP_METRIC_TAG,
        metric_value=hp_metric,
        global_step=hp_step,
    )
    logging.info(
        'Sent HP metric %f at step %d to hyperparameter tuning.',
        hp_metric,
        hp_step,
    )
  else:
    logging.info(
        'Either %s or %s is not included in the evaluation results: %s.',
        eval_metric_name,
        constants.BEST_CKPT_STEP_NAME,
        eval_metric_results,
    )
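
# A sketch of the standard absl entrypoint the full module presumably ends
# with; main(_) above matches the signature app.run() expects:
#
#   if __name__ == '__main__':
#     app.run(main)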