def main()

in community-content/vertex_model_garden/model_oss/tfvision/serving/export_oss_saved_model.py [0:0]


def main(_) -> None:
  """Exports a TF Vision checkpoint as a SavedModel.

  The checkpoint directory and experiment config file(s) are resolved from the
  first matching flag combination below:

  1. `_MAX_TRIAL_COUNT`, `_EVALUATION_METRIC` and `_CONFIG_FILE` present: the
     best trial under `_CHECKPOINT_PATH` is chosen via
     `export_automl_oss_saved_model_lib.get_best_oss_trial`.
  2. `_CHECKPOINT_PATH` and `_CONFIG_FILE` present: the checkpoint path is used
     as given.
  3. `_PROJECT_NAME`, `_LOCATION`, `_HPT_JOB_ID` and `_HPT_RESULT_DIR` present:
     the best HPT trial is looked up with `get_best_hpt_trials`, and its config
     is read from `automl_constants.CFG_FILENAME` inside that trial directory.

  The config files are layered on top of the registered `_EXPERIMENT` config.
  Objective-specific overrides are then applied, the config is validated and
  locked, and the model is exported with either the AutoML-OSS or the plain
  OSS exporter.

  Args:
    _: Unused positional command-line arguments (presumably from `app.run`).

  Raises:
    ValueError: If none of the supported flag combinations is present.
  """
  if (
      _MAX_TRIAL_COUNT.present
      and _EVALUATION_METRIC.present
      and _CONFIG_FILE.present
  ):
    best_ckpt_dir, _ = export_automl_oss_saved_model_lib.get_best_oss_trial(
        _CHECKPOINT_PATH.value, _MAX_TRIAL_COUNT.value, _EVALUATION_METRIC.value
    )
    config_filepath = _CONFIG_FILE.value
  elif _CHECKPOINT_PATH.present and _CONFIG_FILE.present:
    best_ckpt_dir = _CHECKPOINT_PATH.value
    config_filepath = _CONFIG_FILE.value
  elif (
      _PROJECT_NAME.present
      and _LOCATION.present
      and _HPT_JOB_ID.present
      and _HPT_RESULT_DIR.present
  ):
    # Reads HPT results by project, location and hpt_job_id.
    best_ckpt_dir = get_best_hpt_trials(
        _PROJECT_NAME.value,
        _LOCATION.value,
        _HPT_JOB_ID.value,
        _HPT_RESULT_DIR.value,
    )
    # Wrapped in a list so it iterates like the multi-valued config flag.
    config_filepath = [
        os.path.join(best_ckpt_dir, automl_constants.CFG_FILENAME)
    ]
  else:
    raise ValueError('No checkpoint path or HPT Job parameters given.')

  is_yolo = _YOLO_KEY in _EXPERIMENT.value

  params = exp_factory.get_exp_config(_EXPERIMENT.value)
  # Config files are applied in order, each on top of the previous result.
  for config_file in config_filepath or []:
    params = hyperparams.override_params_dict(
        params, config_file, is_strict=False
    )

  if _OBJECTIVE.value == constants.OBJECTIVE_IMAGE_OBJECT_DETECTION:
    detection_override = (
        _PARAMS_OVERRIDE_YOLO if is_yolo else _PARAMS_OVERRIDE_IOD
    )
    params = hyperparams.override_params_dict(
        params, detection_override, is_strict=False
    )
  elif _OBJECTIVE.value == constants.OBJECTIVE_IMAGE_SEGMENTATION:
    # NOTE(review): this override is strict, unlike the others above; confirm
    # that this is intentional.
    params = hyperparams.override_params_dict(
        params, _PARAMS_OVERRIDE_ISG, is_strict=True
    )

  if _USE_BIGSTORE.value:
    params = change_handle(params)

  params.validate()
  params.lock()

  # Point at the best-checkpoint export subdirectory unless the resolved path
  # already ends with it. Trailing slashes are ignored so that ".../best/" is
  # not turned into ".../best/best". An empty or None subdir leaves the path
  # unchanged.
  best_subdir = params.trainer.best_checkpoint_export_subdir
  if (
      best_ckpt_dir
      and best_subdir
      and not best_ckpt_dir.rstrip('/').endswith(best_subdir)
  ):
    best_ckpt_dir = os.path.join(best_ckpt_dir, best_subdir)

  # Arguments shared by every exporter variant below.
  export_kwargs = dict(
      input_type=_IMAGE_TYPE.value,
      batch_size=_BATCH_SIZE.value,
      input_image_size=[int(x) for x in _INPUT_IMAGE_SIZE.value.split(',')],
      params=params,
      checkpoint_path=best_ckpt_dir,
      export_dir=_EXPORT_DIR.value,
      export_saved_model_subdir=_EXPORT_SAVED_MODEL_SUBDIR.value,
      input_name=_INPUT_NAME.value,
  )

  if (
      _LABEL_MAP_PATH.value
      or _LABEL_PATH.value
      or _OBJECTIVE.value == constants.OBJECTIVE_IMAGE_SEGMENTATION
  ):
    export_automl_oss_saved_model_lib.export_inference_graph(
        label_map_path=_LABEL_MAP_PATH.value,
        label_path=_LABEL_PATH.value,
        objective=_OBJECTIVE.value,
        **export_kwargs,
    )
  elif is_yolo:
    export_automl_oss_saved_model_lib.export_inference_graph(
        objective=_OBJECTIVE.value,
        **export_kwargs,
    )
  else:
    export_oss_saved_model_lib.export_inference_graph(**export_kwargs)