def main()

in tensorflow_hub/tools/module_search/search.py [0:0]


def main(argv):
  """Scores every requested module on a dataset and prints the ranking.

  Modules are collected from --module and/or from the file given by
  --module_list (one module per line; blank lines and lines starting with
  "#" are skipped). The --dataset flag has the form "name" or
  "name#num_examples"; the optional suffix limits the number of examples
  passed on in the data spec.

  Args:
    argv: Command-line arguments left over after flag parsing. Only the
      program name is expected.

  Raises:
    app.UsageError: If there are extra positional arguments, --dataset is
      missing or has a non-integer "#num_examples" suffix, or no modules
      were specified.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  if not FLAGS.dataset:
    raise app.UsageError("--dataset is a required argument.")

  module_list = []
  if FLAGS.module:
    module_list.extend(FLAGS.module)

  if FLAGS.module_list:
    with tf.io.gfile.GFile(FLAGS.module_list) as f:
      # splitlines() + strip() tolerates CRLF files and stray whitespace, so
      # module handles never carry a trailing "\r" and indented comments or
      # whitespace-only lines are ignored.
      stripped_lines = (line.strip() for line in f.read().splitlines())
      module_list.extend(
          line for line in stripped_lines
          if line and not line.startswith("#"))

  if not module_list:
    raise app.UsageError(
        "Use --module or --module_list to define which modules to search.")

  ds_sections = FLAGS.dataset.split("#")
  dataset = ds_sections[0]
  train_examples = None
  if len(ds_sections) != 1:
    try:
      train_examples = int(ds_sections[1])
    except ValueError as err:
      # Report a malformed flag value as a usage error rather than letting a
      # bare ValueError traceback escape.
      raise app.UsageError(
          "--dataset must look like 'name' or 'name#num_examples'; "
          "got %r." % FLAGS.dataset) from err
  data_spec = {
    "dataset": dataset,
    "split": "train",
    "num_examples": train_examples,
  }

  results = []
  for module in module_list:
    results.append((
        module, data_spec,
        compute_score(module_spec=module, data_spec=data_spec)))

  df = pd.DataFrame(results, columns=["module", "data", "1nn"])
  # sort_values() and set_index() return new frames; the results must be
  # re-assigned or the printed table stays in input order. Ascending order is
  # kept from the original call.
  # NOTE(review): ascending assumes a lower "1nn" score is better -- confirm
  # against compute_score.
  df = df.filter(["module", "1nn"]).sort_values("1nn").set_index("module")

  with pd.option_context(
      "display.max_rows", None,
      "display.max_columns", None,
      "display.precision", 3,
      # Don't truncate columns (e.g. module name). None is the documented
      # "no limit" value; the old -1 spelling is deprecated since pandas 1.0.
      "display.max_colwidth", None,
      "display.expand_frame_repr", False,  # Don't wrap output.
  ):
    print("# Module ranking for %s" % data_spec)
    print(df)