in retail/recommendation-system/bqml-scann/tfx_pipeline/scann_indexer.py [0:0]
def load_embeddings(embedding_files_pattern, schema_file_path):
    """Load item embeddings from GZIP TFRecord files and L2-normalize them.

    The schema text file is converted to a feature spec, which is used to
    parse the records matched by ``embedding_files_pattern``. Records are
    read once, in file order (no shuffling), one record per batch.

    Args:
        embedding_files_pattern: File pattern of the GZIP-compressed TFRecord
            files, passed to ``make_batched_features_dataset``.
        schema_file_path: Path to the schema text file loaded with
            ``tfdv.load_schema_text``.

    Returns:
        A ``(vocabulary, embeddings)`` tuple. ``vocabulary`` is a list of the
        decoded ``item_Id`` strings; ``embeddings`` is a NumPy array built
        from the per-record embedding vectors, in the same order as
        ``vocabulary``. Each vector is scaled to unit L2 norm, except
        zero-norm vectors, which are kept unchanged.
    """
    embeddings = []
    vocabulary = []

    logging.info('Loading schema...')
    schema = tfdv.load_schema_text(schema_file_path)
    feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec
    logging.info('Schema is loaded.')

    def _gzip_reader_fn(filenames):
        # The embedding files are read as GZIP-compressed TFRecords.
        return tf.data.TFRecordDataset(filenames, compression_type='GZIP')

    # batch_size=1 so each iteration below yields exactly one record;
    # num_epochs=1 and shuffle=False give a single, ordered pass.
    dataset = tf.data.experimental.make_batched_features_dataset(
        embedding_files_pattern,
        batch_size=1,
        num_epochs=1,
        features=feature_spec,
        reader=_gzip_reader_fn,
        shuffle=False,
    )

    # Read embeddings from tfrecord files.
    logging.info('Loading embeddings from files...')
    for tfrecord_batch in dataset:
        # Index [0][0] picks the single id of the single record in the batch.
        item_id = tfrecord_batch["item_Id"].numpy()[0][0].decode()
        embedding = tfrecord_batch["embedding"].numpy()[0]

        norm = np.linalg.norm(embedding)
        if norm > 0:
            embedding = embedding / norm
        else:
            # Dividing by a zero norm would fill the vector with NaNs, which
            # would then flow into whatever consumes the returned array.
            # Keep the vector as-is and make the anomaly visible instead.
            logging.warning(
                'Embedding for item %s has no positive norm; '
                'leaving it unnormalized.', item_id)

        vocabulary.append(item_id)
        embeddings.append(embedding)
    logging.info('Embeddings loaded.')

    return vocabulary, np.array(embeddings)