in tensorflow_hub/tools/module_search/search.py [0:0]
def main(argv):
  """Scores the requested TF Hub modules on a dataset and prints a ranking.

  Modules are taken from --module and/or the file named by --module_list
  (one module per line; blank lines and lines starting with "#" are skipped).
  --dataset has the form "<name>" or "<name>#<num_train_examples>". Every
  module is scored with `compute_score` on the "train" split of that dataset,
  and the resulting table is printed sorted by the "1nn" score in ascending
  order.

  Args:
    argv: Positional command-line arguments remaining after flag parsing;
      only the program name itself is accepted.

  Raises:
    app.UsageError: If extra positional arguments are given, --dataset is
      missing or its example count is not an integer, or no modules were
      specified.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  if not FLAGS.dataset:
    raise app.UsageError("--dataset is a required argument.")

  module_list = []
  if FLAGS.module:
    module_list.extend(FLAGS.module)
  if FLAGS.module_list:
    with tf.io.gfile.GFile(FLAGS.module_list) as f:
      # Strip each line so trailing whitespace (or the "\r" of CRLF files)
      # does not become part of the module handle.
      stripped_lines = (line.strip() for line in f.read().split("\n"))
      module_list.extend(
          line for line in stripped_lines
          if line and not line.startswith("#"))
  if not module_list:
    raise app.UsageError(
        "Use --module or --module_list to define which modules to search.")

  # "<name>#<count>": anything after a second "#" is ignored.
  ds_sections = FLAGS.dataset.split("#")
  dataset = ds_sections[0]
  train_examples = None
  if len(ds_sections) > 1:
    try:
      train_examples = int(ds_sections[1])
    except ValueError:
      raise app.UsageError(
          "--dataset must be <name> or <name>#<num_examples>, got %r." %
          FLAGS.dataset) from None
  data_spec = {
      "dataset": dataset,
      "split": "train",
      "num_examples": train_examples,
  }

  # Modules are scored sequentially, in the order they were specified.
  results = [(module, compute_score(module_spec=module, data_spec=data_spec))
             for module in module_list]

  # DataFrame.sort_values / set_index return new frames rather than mutating
  # in place, so the result must be reassigned. Ascending order (the
  # sort_values default) presumably means a lower "1nn" score ranks higher —
  # NOTE(review): confirm against compute_score. mergesort is stable, so tied
  # modules keep their input order.
  df = pd.DataFrame(results, columns=["module", "1nn"])
  df = df.sort_values("1nn", kind="mergesort").set_index("module")

  with pd.option_context(
      "display.max_rows", None,
      "display.max_columns", None,
      "display.precision", 3,
      # None disables truncation (e.g. of long module names); the old -1
      # sentinel is rejected by current pandas.
      "display.max_colwidth", None,
      "display.expand_frame_repr", False,  # Don't wrap output.
  ):
    print("# Module ranking for %s" % data_spec)
    print(df)