def main()

in community-content/vertex_model_garden/model_oss/tfvision/train_hpt_oss.py


def main(_):
  log_level = _LOG_LEVEL.value
  if log_level and log_level in ['FATAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG']:
    logging.set_verbosity(log_level)
  params = parse_params()
  logging.info('The actual training parameters are:\n%s', params.as_dict())
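  # Each hyperparameter tuning trial writes its outputs to a separate
  # subdirectory of FLAGS.model_dir.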
  model_dir = os.path.join(
      FLAGS.model_dir,
      'trial_' + hypertune_utils.get_trial_id_from_environment(),
  )
  logging.info('model_dir in this trial is: %s', model_dir)
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files; otherwise a continuous eval job
    # may race against the train job when writing the same file.
    train_utils.serialize_config(params, model_dir)

  # Sets the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can significantly speed up the model by using float16 on GPUs and bfloat16
  # on TPUs. loss_scale takes effect only when the dtype is float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu,
  )
  with distribution_strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir,
  )

  train_utils.save_gin_config(FLAGS.mode, model_dir)

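  # Determine which evaluation metric to report for the requested tuning
  # objective.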
  eval_metric_name = get_best_eval_metric(_OBJECTIVE.value, params)

  eval_filepath = os.path.join(
      model_dir, constants.BEST_CKPT_DIRNAME, constants.BEST_CKPT_EVAL_FILENAME
  )
  logging.info('Load eval metrics from: %s.', eval_filepath)
  wait_for_evaluation_file(eval_filepath, _MAX_EVAL_WAIT_TIME.value)

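  # Read the best checkpoint's eval results and report the chosen metric to
  # the hyperparameter tuning service.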
  with tf.io.gfile.GFile(eval_filepath, 'rb') as f:
    eval_metric_results = json.load(f)
    logging.info('eval metrics are: %s.', eval_metric_results)
    if (
        eval_metric_name in eval_metric_results
        and constants.BEST_CKPT_STEP_NAME in eval_metric_results
    ):
      hp_metric = eval_metric_results[eval_metric_name]
      hp_step = int(eval_metric_results[constants.BEST_CKPT_STEP_NAME])
      hpt = hypertune.HyperTune()
      hpt.report_hyperparameter_tuning_metric(
          hyperparameter_metric_tag=constants.HP_METRIC_TAG,
          metric_value=hp_metric,
          global_step=hp_step,
      )
      logging.info(
          'Sent HP metric %f at step %d to hyperparameter tuning.',
          hp_metric,
          hp_step,
      )
    else:
      logging.info(
          'Either %s or %s is not included in the evaluation results: %s.',
          eval_metric_name,
          constants.BEST_CKPT_STEP_NAME,
          eval_metric_results,
      )
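

For reference, the per-trial model_dir above relies on hypertune_utils.get_trial_id_from_environment(). A minimal sketch of such a helper, assuming the trial id is exposed through the CLOUD_ML_TRIAL_ID environment variable (which Vertex AI hyperparameter tuning jobs typically set); the actual hypertune_utils implementation may differ:

import os


def get_trial_id_from_environment_sketch() -> str:
  """Returns the current tuning trial id, or '0' outside a tuning job (sketch)."""
  return os.environ.get('CLOUD_ML_TRIAL_ID', '0')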
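
Similarly, wait_for_evaluation_file blocks until the best-checkpoint eval file appears or the wait budget is exhausted. A minimal sketch under the assumption that it simply polls tf.io.gfile.exists with a fixed sleep interval; the helper in train_hpt_oss.py may behave differently:

import time

import tensorflow as tf
from absl import logging


def wait_for_evaluation_file_sketch(eval_filepath: str,
                                    max_wait_secs: int,
                                    poll_secs: int = 60) -> bool:
  """Polls until `eval_filepath` exists; returns False if the deadline passes (sketch)."""
  deadline = time.time() + max_wait_secs
  while not tf.io.gfile.exists(eval_filepath):
    if time.time() >= deadline:
      logging.warning('Gave up waiting for %s after %d seconds.', eval_filepath,
                      max_wait_secs)
      return False
    logging.info('Evaluation file %s not found yet; sleeping %d seconds.',
                 eval_filepath, poll_secs)
    time.sleep(poll_secs)
  return True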