text-semantic-search/index_builder/builder/index.py (45 lines of code) (raw):

#!/usr/bin/python # # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import tensorflow as tf import numpy as np import logging import pickle import os from annoy import AnnoyIndex VECTOR_LENGTH = 512 METRIC = 'angular' def build_index(embedding_files_pattern, index_filename, num_trees=100): annoy_index = AnnoyIndex(VECTOR_LENGTH, metric=METRIC) mapping = {} embed_files = tf.gfile.Glob(embedding_files_pattern)[:250] logging.info('{} embedding files are found.'.format(len(embed_files))) item_counter = 0 for f, embed_file in enumerate(embed_files): logging.info('Loading embeddings in file {} of {}...'.format( f, len(embed_files))) record_iterator = tf.python_io.tf_record_iterator( path=embed_file) for string_record in record_iterator: example = tf.train.Example() example.ParseFromString(string_record) string_identifier = example.features.feature['id'].bytes_list.value[0] mapping[item_counter] = string_identifier embedding = np.array( example.features.feature['embedding'].float_list.value) annoy_index.add_item(item_counter, embedding) item_counter += 1 logging.info('Loaded {} items to the index'.format(item_counter)) logging.info('Start building the index with {} trees...'.format(num_trees)) annoy_index.build(n_trees=num_trees) logging.info('Index is successfully built.') logging.info('Saving index to disk...') annoy_index.save(index_filename) logging.info('Index is saved to disk.') logging.info("Index file size: {} GB".format( round(os.path.getsize(index_filename) / float(1024 ** 3), 2))) annoy_index.unload() logging.info('Saving mapping to disk...') with open(index_filename + '.mapping', 'wb') as handle: pickle.dump(mapping, handle, protocol=pickle.HIGHEST_PROTOCOL) logging.info('Mapping is saved to disk.') logging.info("Mapping file size: {} MB".format( round(os.path.getsize(index_filename + '.mapping') / float(1024 ** 2), 2)))