retail/recommendation-system/bqml-scann/index_builder/builder/indexer.py

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Builds and saves a ScaNN approximate nearest-neighbour index from item embeddings."""

import math
import os
import pickle

import numpy as np
import scann
import tensorflow as tf

METRIC = 'dot_product'
DIMENSIONS_PER_BLOCK = 2
ANISOTROPIC_QUANTIZATION_THRESHOLD = 0.2
NUM_NEIGHBOURS = 10
NUM_LEAVES_TO_SEARCH = 200
REORDER_NUM_NEIGHBOURS = 200
TOKENS_FILE_NAME = 'tokens'


def load_embeddings(embedding_files_pattern):
  """Loads embeddings from CSV files of the form '<item_id>,<v0>,<v1>,...'."""
  embedding_list = []
  tokens = []
  embed_files = tf.io.gfile.glob(embedding_files_pattern)
  print(f'{len(embed_files)} embedding files found.')

  for file_idx, embed_file in enumerate(embed_files):
    print(f'Loading embeddings from file {file_idx + 1} of {len(embed_files)}...')
    with tf.io.gfile.GFile(embed_file, 'r') as file_reader:
      for line in file_reader.readlines():
        parts = line.split(',')
        item_id = parts[0]
        embedding = np.array([float(v) for v in parts[1:]])
        # L2-normalize so that dot-product search behaves like cosine
        # similarity over the original embeddings.
        normalized_embedding = embedding / np.linalg.norm(embedding)
        embedding_list.append(normalized_embedding)
        tokens.append(item_id)

  print(f'{len(embedding_list)} embeddings loaded.')
  return tokens, np.array(embedding_list)


def build_index(embeddings, num_leaves):
  """Builds a tree-AH ScaNN index over the given embedding matrix."""
  data_size = embeddings.shape[0]
  if not num_leaves:
    # Common heuristic: partition the dataset into roughly sqrt(n) leaves.
    num_leaves = int(math.sqrt(data_size))

  print('Start building the ScaNN index...')
  scann_builder = scann.scann_ops.builder(embeddings, NUM_NEIGHBOURS, METRIC)
  # Partition the dataset into num_leaves clusters; at query time, only the
  # closest NUM_LEAVES_TO_SEARCH clusters are searched.
  scann_builder = scann_builder.tree(
      num_leaves=num_leaves,
      num_leaves_to_search=NUM_LEAVES_TO_SEARCH,
      training_sample_size=data_size)
  # Score candidates with asymmetric hashing using anisotropic quantization.
  scann_builder = scann_builder.score_ah(
      DIMENSIONS_PER_BLOCK,
      anisotropic_quantization_threshold=ANISOTROPIC_QUANTIZATION_THRESHOLD)
  # Re-rank the top REORDER_NUM_NEIGHBOURS candidates with exact scores.
  scann_builder = scann_builder.reorder(REORDER_NUM_NEIGHBOURS)
  scann_index = scann_builder.build()
  print('ScaNN index is built.')
  return scann_index


def save_index(index, tokens, output_dir):
  """Exports the index as a SavedModel and pickles the token list next to it."""
  print('Saving index as a SavedModel...')
  module = index.serialize_to_module()
  tf.saved_model.save(module, output_dir, signatures=None, options=None)
  print(f'Index is saved to {output_dir}.')

  print('Saving tokens file...')
  # The index returns row indices; the tokens file maps them back to item ids.
  tokens_file_path = os.path.join(output_dir, TOKENS_FILE_NAME)
  with tf.io.gfile.GFile(tokens_file_path, 'wb') as handle:
    pickle.dump(tokens, handle, protocol=pickle.HIGHEST_PROTOCOL)
  print(f'Tokens file is saved to {tokens_file_path}.')


def build(embedding_files_pattern, output_dir, num_leaves=None):
  print('Indexer started...')
  tokens, embeddings = load_embeddings(embedding_files_pattern)
  index = build_index(embeddings, num_leaves)
  save_index(index, tokens, output_dir)
  print('Indexer finished.')
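

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original module).
# The GCS paths and the 50-dimensional query below are hypothetical
# placeholders, and reloading the index via
# scann.scann_ops.searcher_from_module is an assumption about the ScaNN
# TF-ops API version in use.
if __name__ == '__main__':
  index_dir = 'gs://my-bucket/index'  # hypothetical output location

  # 1. Build and export the index from CSV files of '<item_id>,<v0>,<v1>,...'.
  build(
      embedding_files_pattern='gs://my-bucket/embeddings/embeddings-*.csv',
      output_dir=index_dir,
      num_leaves=None)  # None => int(sqrt(num_embeddings)) partitions.

  # 2. Reload the SavedModel and the token list, then run a sample query.
  searcher = scann.scann_ops.searcher_from_module(
      tf.saved_model.load(index_dir))
  with tf.io.gfile.GFile(os.path.join(index_dir, TOKENS_FILE_NAME), 'rb') as f:
    tokens = pickle.load(f)

  query = np.random.rand(50).astype(np.float32)  # 50 is a placeholder dim.
  query /= np.linalg.norm(query)  # Normalize like the indexed embeddings.
  neighbours, distances = searcher.search(query)
  print([tokens[i] for i in neighbours.numpy()])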