def get_best_eval_metric()

in community-content/vertex_model_garden/model_oss/tfvision/train_hpt_oss.py [0:0]


def get_best_eval_metric(objective: str, params: Any) -> str:
  """Resolves the eval metric used to select the best checkpoint.

  A truthy `params.trainer.best_checkpoint_eval_metric` is returned as-is.
  Otherwise the default metric for `objective` is taken from `constants`. For
  image classification, that default depends on whether
  `params.task.train_data.is_multilabel` is truthy (treated as False when the
  attribute path is missing).

  Args:
    objective: The objective of this training job.
    params: Experiment config.

  Returns:
    Eval metric to use.

  Raises:
    ValueError: If params does not set best_checkpoint_eval_metric and the
      objective is not one of the supported objectives.
  """

  def _follow(obj: Any, *path: str, default: Any = None) -> Any:
    """Walks `path` attribute by attribute; returns `default` on a miss."""
    try:
      for attr_name in path:
        obj = getattr(obj, attr_name)
    except AttributeError:
      return default
    return obj

  configured = _follow(params, 'trainer', 'best_checkpoint_eval_metric')
  if configured:
    # An explicitly configured metric always takes precedence.
    return configured

  # Nothing configured: fall back to the per-objective default.
  if objective == constants.OBJECTIVE_IMAGE_CLASSIFICATION:
    multilabel = _follow(
        params, 'task', 'train_data', 'is_multilabel', default=False
    )
    if multilabel:
      return constants.IMAGE_CLASSIFICATION_MULTI_LABEL_BEST_EVAL_METRIC
    return constants.IMAGE_CLASSIFICATION_SINGLE_LABEL_BEST_EVAL_METRIC

  if objective == constants.OBJECTIVE_IMAGE_OBJECT_DETECTION:
    return constants.IMAGE_OBJECT_DETECTION_BEST_EVAL_METRIC

  if objective == constants.OBJECTIVE_IMAGE_SEGMENTATION:
    return constants.IMAGE_SEGMENTATION_BEST_EVAL_METRIC

  raise ValueError(
      'The objective must be {}, {}, or {}.'.format(
          constants.OBJECTIVE_IMAGE_CLASSIFICATION,
          constants.OBJECTIVE_IMAGE_OBJECT_DETECTION,
          constants.OBJECTIVE_IMAGE_SEGMENTATION,
      )
  )