in community-content/vertex_model_garden/model_oss/tfvision/serving/export_oss_saved_model.py [0:0]
def main(_) -> None:
  """Exports a trained TF Vision checkpoint as a SavedModel.

  The checkpoint directory and the config files are selected from the flags,
  checked in this order:

  1. `_MAX_TRIAL_COUNT`, `_EVALUATION_METRIC` and `_CONFIG_FILE` are set: the
     best trial under `_CHECKPOINT_PATH` is picked with
     `export_automl_oss_saved_model_lib.get_best_oss_trial`, and the
     `_CONFIG_FILE` files are used.
  2. `_CHECKPOINT_PATH` and `_CONFIG_FILE` are set: that checkpoint path and
     those config files are used directly.
  3. `_PROJECT_NAME`, `_LOCATION`, `_HPT_JOB_ID` and `_HPT_RESULT_DIR` are set:
     the best HPT trial is looked up with `get_best_hpt_trials`, and the
     `automl_constants.CFG_FILENAME` file inside that trial's directory is the
     only config file used.

  The experiment config named by `_EXPERIMENT` is then overridden by each
  config file and by objective-specific overrides, optionally rewritten by
  `change_handle` (when `_USE_BIGSTORE` is set), validated and locked. The
  checkpoint path is pointed at `params.trainer.best_checkpoint_export_subdir`
  and the model is exported with the matching export library.

  Args:
    _: Unused (presumably the argv passed by `app.run`).

  Raises:
    ValueError: If none of the supported flag combinations is given.
  """
  if (
      _MAX_TRIAL_COUNT.present
      and _EVALUATION_METRIC.present
      and _CONFIG_FILE.present
  ):
    # NOTE(review): `_CHECKPOINT_PATH.present` is not checked in this branch,
    # so the flag's default is used when it is absent -- confirm intended.
    best_ckpt_dir, _ = export_automl_oss_saved_model_lib.get_best_oss_trial(
        _CHECKPOINT_PATH.value, _MAX_TRIAL_COUNT.value, _EVALUATION_METRIC.value
    )
    config_filepath = _CONFIG_FILE.value
  elif _CHECKPOINT_PATH.present and _CONFIG_FILE.present:
    best_ckpt_dir = _CHECKPOINT_PATH.value
    config_filepath = _CONFIG_FILE.value
  elif (
      _PROJECT_NAME.present
      and _LOCATION.present
      and _HPT_JOB_ID.present
      and _HPT_RESULT_DIR.present
  ):
    # Reads HPT results by project, location and hpt_job_id.
    best_ckpt_dir = get_best_hpt_trials(
        _PROJECT_NAME.value,
        _LOCATION.value,
        _HPT_JOB_ID.value,
        _HPT_RESULT_DIR.value,
    )
    # The config file saved next to the best trial is the only override.
    config_filepath = [
        os.path.join(best_ckpt_dir, automl_constants.CFG_FILENAME)
    ]
  else:
    raise ValueError('No checkpoint path or HPT Job parameters given.')

  params = exp_factory.get_exp_config(_EXPERIMENT.value)
  # Each config file is layered on top of the result of the previous one.
  for config_file in config_filepath or []:
    params = hyperparams.override_params_dict(
        params, config_file, is_strict=False
    )

  # Objective-specific overrides are applied after the config files.
  if _OBJECTIVE.value == constants.OBJECTIVE_IMAGE_OBJECT_DETECTION:
    detection_override = (
        _PARAMS_OVERRIDE_YOLO
        if _YOLO_KEY in _EXPERIMENT.value
        else _PARAMS_OVERRIDE_IOD
    )
    params = hyperparams.override_params_dict(
        params, detection_override, is_strict=False
    )
  elif _OBJECTIVE.value == constants.OBJECTIVE_IMAGE_SEGMENTATION:
    # Note: unlike the other overrides, this one is applied strictly.
    params = hyperparams.override_params_dict(
        params, _PARAMS_OVERRIDE_ISG, is_strict=True
    )
  if _USE_BIGSTORE.value:
    params = change_handle(params)
  params.validate()
  params.lock()

  # Point the checkpoint path at the best-checkpoint export subdirectory,
  # unless it already ends with that subdirectory. Trailing '/' characters are
  # ignored and only whole path components are matched, so 'ckpt/best/' is not
  # turned into 'ckpt/best/best' and 'my_best' is not mistaken for 'best'.
  export_subdir = (
      params.trainer.best_checkpoint_export_subdir or ''
  ).rstrip('/')
  if best_ckpt_dir and export_subdir:
    ckpt_dir_no_slash = best_ckpt_dir.rstrip('/')
    already_in_subdir = (
        ckpt_dir_no_slash == export_subdir
        or ckpt_dir_no_slash.endswith('/' + export_subdir)
    )
    if not already_in_subdir:
      best_ckpt_dir = os.path.join(best_ckpt_dir, export_subdir)

  input_image_size = [int(x) for x in _INPUT_IMAGE_SIZE.value.split(',')]
  # Arguments shared by every export path below.
  export_kwargs = dict(
      input_type=_IMAGE_TYPE.value,
      batch_size=_BATCH_SIZE.value,
      input_image_size=input_image_size,
      params=params,
      checkpoint_path=best_ckpt_dir,
      export_dir=_EXPORT_DIR.value,
      export_saved_model_subdir=_EXPORT_SAVED_MODEL_SUBDIR.value,
      input_name=_INPUT_NAME.value,
  )
  if (
      _LABEL_MAP_PATH.value
      or _LABEL_PATH.value
      or _OBJECTIVE.value == constants.OBJECTIVE_IMAGE_SEGMENTATION
  ):
    # Label files were given (or the objective is segmentation): use the
    # AutoML export library and forward the label paths.
    export_automl_oss_saved_model_lib.export_inference_graph(
        label_map_path=_LABEL_MAP_PATH.value,
        label_path=_LABEL_PATH.value,
        objective=_OBJECTIVE.value,
        **export_kwargs,
    )
  elif _YOLO_KEY in _EXPERIMENT.value:
    export_automl_oss_saved_model_lib.export_inference_graph(
        objective=_OBJECTIVE.value,
        **export_kwargs,
    )
  else:
    export_oss_saved_model_lib.export_inference_graph(**export_kwargs)