in community-content/vertex_model_garden/model_oss/tfvision/train_hpt_oss.py [0:0]
def get_best_eval_metric(objective: str, params: Any) -> str:
  """Returns the eval metric used to select the best checkpoint.

  A truthy `params.trainer.best_checkpoint_eval_metric` is returned as-is.
  If it is missing (any attribute on the path absent) or falsy, a default is
  picked from `objective`. For image classification, the default also depends
  on `params.task.train_data.is_multilabel`, which is treated as False when
  absent.

  Args:
    objective: The objective of this training job.
    params: Experiment config.

  Returns:
    Eval metric to use.

  Raises:
    ValueError: If params does not have best_checkpoint_eval_metric set and the
      objective is not valid.
  """
  # getattr with a default only suppresses AttributeError, so a missing
  # `trainer` (or a `trainer` of None) behaves the same as a missing metric.
  configured_metric = getattr(
      getattr(params, 'trainer', None), 'best_checkpoint_eval_metric', None
  )
  if configured_metric:
    return configured_metric

  # No usable metric in params: fall back to the per-objective default.
  if objective == constants.OBJECTIVE_IMAGE_CLASSIFICATION:
    train_data = getattr(getattr(params, 'task', None), 'train_data', None)
    if getattr(train_data, 'is_multilabel', False):
      return constants.IMAGE_CLASSIFICATION_MULTI_LABEL_BEST_EVAL_METRIC
    return constants.IMAGE_CLASSIFICATION_SINGLE_LABEL_BEST_EVAL_METRIC
  if objective == constants.OBJECTIVE_IMAGE_OBJECT_DETECTION:
    return constants.IMAGE_OBJECT_DETECTION_BEST_EVAL_METRIC
  if objective == constants.OBJECTIVE_IMAGE_SEGMENTATION:
    return constants.IMAGE_SEGMENTATION_BEST_EVAL_METRIC

  raise ValueError(
      'The objective must be {}, {}, or {}.'.format(
          constants.OBJECTIVE_IMAGE_CLASSIFICATION,
          constants.OBJECTIVE_IMAGE_OBJECT_DETECTION,
          constants.OBJECTIVE_IMAGE_SEGMENTATION,
      )
  )