in tensorflow_hub/tools/make_nearest_neighbour_index/index_builder.py [0:0]
def run(args):
    """Builds an Annoy index from embedding files and saves it to disk.

    Every ``*.tfrecords`` file found in ``args.embed_output_dir`` is read and
    each record's embedding is added to an Annoy index under a sequential
    integer id. The index and a pickled ``{id: item}`` mapping are written to
    ``args.index_output_dir``. If a random projection matrix file is present
    next to the embeddings, it is copied into the index output directory too.

    Note: an existing ``args.index_output_dir`` is deleted and recreated.

    Args:
        args: Object exposing the attributes ``embed_output_dir``,
            ``index_output_dir`` and ``num_trees``.

    Raises:
        ValueError: If no ``*.tfrecords`` file is found in
            ``args.embed_output_dir``.
    """
    embed_output_dir = args.embed_output_dir
    output_dir = args.index_output_dir
    num_trees = args.num_trees
    index_file_path = os.path.join(output_dir, _INDEX_FILENAME)
    mapping_file_path = os.path.join(output_dir, _MAPPING_FILENAME)

    # Validate the inputs *before* touching the output directory, so that a
    # previously built index is not destroyed by a run that cannot succeed.
    embed_files = tf.io.gfile.glob(
        os.path.join(embed_output_dir, '*.tfrecords'))
    num_files = len(embed_files)
    print('Found {} embedding file(s).'.format(num_files))
    if not embed_files:
        raise ValueError(
            "No '*.tfrecords' embedding files found in {!r}.".format(
                embed_output_dir))

    # Always start from an empty output directory.
    if tf.io.gfile.exists(output_dir):
        print('Index output directory...')
        tf.io.gfile.rmtree(output_dir)
    print('Creating empty output directory...')
    tf.io.gfile.makedirs(output_dir)

    # The embedding size is inferred from the first embedding file.
    dimensions = _infer_dimensions(embed_files[0])
    print('Embedding size: {}'.format(dimensions))

    annoy_index = annoy.AnnoyIndex(dimensions, metric=_METRIC)

    # Mapping between the item and its identifier in the index
    mapping = {}
    item_counter = 0
    for i, embed_file in enumerate(embed_files):
        print('Loading embeddings in file {} of {}...'.format(
            i + 1, num_files))
        dataset = tf.data.TFRecordDataset(embed_file)
        for record in dataset.map(_parse_example):
            item = record['item'].numpy().decode('utf-8')
            embedding = record['embedding'].values.numpy()
            mapping[item_counter] = item
            annoy_index.add_item(item_counter, embedding)
            item_counter += 1
            # Periodic progress report for large inputs.
            if item_counter % 200000 == 0:
                print('{} items loaded to the index'.format(item_counter))

    print('A total of {} items added to the index'.format(item_counter))

    print('Building the index with {} trees...'.format(num_trees))
    annoy_index.build(n_trees=num_trees)
    print('Index is successfully built.')

    print('Saving index to disk...')
    annoy_index.save(index_file_path)
    print('Index is saved to disk. File size: {} GB'.format(
        round(os.path.getsize(index_file_path) / float(1024**3), 2)))
    # The index is persisted; it no longer needs to stay loaded.
    annoy_index.unload()

    print('Saving mapping to disk...')
    with open(mapping_file_path, 'wb') as handle:
        pickle.dump(mapping, handle, protocol=pickle.HIGHEST_PROTOCOL)
    print('Mapping is saved to disk. File size: {} MB'.format(
        round(os.path.getsize(mapping_file_path) / float(1024**2), 2)))

    # Use tf.io.gfile here as well: the embedding files above are discovered
    # through tf.io.gfile, so the embedding directory may live on a
    # filesystem that os.path/shutil cannot see, in which case the matrix
    # would silently not be copied.
    random_projection_file_path = os.path.join(
        embed_output_dir, _RANDOM_PROJECTION_FILENAME)
    if tf.io.gfile.exists(random_projection_file_path):
        tf.io.gfile.copy(
            random_projection_file_path,
            os.path.join(output_dir, _RANDOM_PROJECTION_FILENAME),
            overwrite=True)
        print('Random projection matrix file copies to index output directory.')