in tensorflow_hub/tools/make_nearest_neighbour_index/make_nearest_neighbour_index.py [0:0]
def validate_args(args, *, flags=None):
  """Validates the command line arguments specified by the user.

  Args:
    args: The positional command line arguments. `args[1]` must name the
      operation to run (index 0 is presumably the program name, as in
      `sys.argv` -- confirm against the caller).
    flags: Optional object exposing the flag values as attributes. Defaults to
      the module-level `FLAGS`; passing it explicitly allows validating
      without relying on globally parsed flags.

  Returns:
    The requested operation: one of "generate", "build", "e2e" or "query".

  Raises:
    ValueError: If no valid operation is given, or if a flag required by the
      requested operation is missing or invalid.
  """
  if len(args) < 2 or args[1] not in ["generate", "build", "e2e", "query"]:
    raise ValueError("You need to specify one of four operations: "
                     "generate | build | e2e | query")
  # Resolved only after the operation check, so a bad operation is reported
  # without touching the flag values at all.
  if flags is None:
    flags = FLAGS

  def _validate_generate_args():
    """Validates generate operation args."""
    if not flags.data_file_pattern:
      raise ValueError(
          "You must provide --data_file_pattern to generate embeddings for.")
    if not flags.module_url:
      raise ValueError(
          "You must provide --module_url to use for embeddings generation.")
    if not flags.embed_output_dir:
      raise ValueError(
          "You must provide --embed_output_dir to store the embedding files.")
    # A falsy projected_dim (None or 0) skips this check entirely.
    if flags.projected_dim and flags.projected_dim < 1:
      raise ValueError("--projected_dim must be a positive integer value.")

  def _validate_build_args(e2e=False):
    """Validates build operation args.

    Args:
      e2e: True when building as part of the "e2e" operation. The
        --embed_output_dir check is then skipped here because the generate
        validation, run first for e2e, already requires it.
    """
    if not flags.embed_output_dir and not e2e:
      raise ValueError(
          "You must provide --embed_output_dir of the embeddings "
          "to build the ANN index for.")
    if not flags.index_output_dir:
      raise ValueError(
          "You must provide --index_output_dir to store the index files.")
    if not flags.num_trees or flags.num_trees < 1:
      raise ValueError(
          "You must provide --num_trees as a positive integer value.")

  def _validate_query_args():
    """Validates query operation args."""
    if not flags.module_url:
      raise ValueError("You must provide --module_url to use for query.")
    if not flags.index_output_dir:
      raise ValueError("You must provide --index_output_dir to use for query.")

  operation = args[1]
  if operation == "generate":
    _validate_generate_args()
  elif operation == "build":
    _validate_build_args()
  elif operation == "e2e":
    _validate_generate_args()
    _validate_build_args(e2e=True)
  else:
    _validate_query_args()
  return operation