# def run()
# in tensorflow_hub/tools/make_nearest_neighbour_index/index_builder.py [0:0]


def run(args):
  """Runs the index building process.

  Loads item embeddings from TFRecord files in ``args.embed_output_dir``,
  adds them to an Annoy index, builds the index with ``args.num_trees``
  trees, and writes the index plus an item-id -> item mapping (pickle) to
  ``args.index_output_dir``. If a random projection matrix file exists next
  to the embeddings, it is copied alongside the index.

  Args:
    args: Parsed command-line arguments. Reads `embed_output_dir`,
      `index_output_dir`, and `num_trees`.

  Raises:
    ValueError: If no embedding files are found in `args.embed_output_dir`.
  """

  embed_output_dir = args.embed_output_dir
  output_dir = args.index_output_dir
  num_trees = args.num_trees
  index_file_path = os.path.join(output_dir, _INDEX_FILENAME)
  mapping_file_path = os.path.join(output_dir, _MAPPING_FILENAME)

  # Start from a clean output directory so stale index/mapping files from a
  # previous run cannot be mixed with the new artifacts.
  if tf.io.gfile.exists(output_dir):
    print('Deleting existing index output directory...')
    tf.io.gfile.rmtree(output_dir)
  print('Creating empty output directory...')
  tf.io.gfile.makedirs(output_dir)

  embed_files = tf.io.gfile.glob(os.path.join(embed_output_dir, '*.tfrecords'))
  num_files = len(embed_files)
  if not embed_files:
    # Fail fast with a clear message instead of an IndexError below.
    raise ValueError(
        'No *.tfrecords embedding files found in: {}'.format(embed_output_dir))
  print('Found {} embedding file(s).'.format(num_files))

  # Embedding dimensionality is inferred from the first record; all files are
  # assumed to contain embeddings of the same size.
  dimensions = _infer_dimensions(embed_files[0])
  print('Embedding size: {}'.format(dimensions))

  annoy_index = annoy.AnnoyIndex(dimensions, metric=_METRIC)

  # Mapping between the item and its identifier in the index.
  mapping = {}

  item_counter = 0
  for file_number, embed_file in enumerate(embed_files, start=1):
    print('Loading embeddings in file {} of {}...'.format(
        file_number, num_files))
    dataset = tf.data.TFRecordDataset(embed_file)
    for record in dataset.map(_parse_example):
      item = record['item'].numpy().decode('utf-8')
      embedding = record['embedding'].values.numpy()
      mapping[item_counter] = item
      annoy_index.add_item(item_counter, embedding)
      item_counter += 1
      # Progress heartbeat for large corpora.
      if item_counter % 200000 == 0:
        print('{} items loaded to the index'.format(item_counter))

  print('A total of {} items added to the index'.format(item_counter))

  print('Building the index with {} trees...'.format(num_trees))
  annoy_index.build(n_trees=num_trees)
  print('Index is successfully built.')

  print('Saving index to disk...')
  annoy_index.save(index_file_path)
  print('Index is saved to disk. File size: {} GB'.format(
      round(os.path.getsize(index_file_path) / float(1024**3), 2)))
  # Free the memory held by the index now that it is persisted.
  annoy_index.unload()

  print('Saving mapping to disk...')
  with open(mapping_file_path, 'wb') as handle:
    pickle.dump(mapping, handle, protocol=pickle.HIGHEST_PROTOCOL)
  print('Mapping is saved to disk. File size: {} MB'.format(
      round(os.path.getsize(mapping_file_path) / float(1024**2), 2)))

  # If the embedding step produced a random projection matrix, ship it with
  # the index so queries can apply the same projection.
  random_projection_file_path = os.path.join(
      args.embed_output_dir, _RANDOM_PROJECTION_FILENAME)
  if os.path.exists(random_projection_file_path):
    shutil.copy(
        random_projection_file_path, os.path.join(
            args.index_output_dir, _RANDOM_PROJECTION_FILENAME))
    print('Random projection matrix file copied to index output directory.')