def load_embeddings()

in retail/recommendation-system/bqml-scann/tfx_pipeline/scann_indexer.py [0:0]


def load_embeddings(embedding_files_pattern, schema_file_path):
  """Loads embeddings from GZIP TFRecord files and L2-normalizes them.

  Records are parsed with the feature spec derived from the schema file and
  read one at a time, unshuffled, for a single epoch, so the returned
  vocabulary and embedding rows are aligned by position.

  Args:
    embedding_files_pattern: File pattern(s) of the GZIP-compressed TFRecord
      files, forwarded to
      `tf.data.experimental.make_batched_features_dataset`.
    schema_file_path: Path to the text schema loaded with
      `tfdv.load_schema_text`.

  Returns:
    A `(vocabulary, embeddings)` tuple. `vocabulary` is a list with the
    decoded "item_Id" value of each record; `embeddings` is a NumPy array
    built from the per-record "embedding" vectors, each divided by its
    Euclidean norm. Zero-norm vectors are kept unchanged rather than being
    turned into NaNs.
  """
  embeddings = []
  vocabulary = []

  logging.info('Loading schema...')
  schema = tfdv.load_schema_text(schema_file_path)
  feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec
  logging.info('Schema is loaded.')

  def _gzip_reader_fn(filenames):
    return tf.data.TFRecordDataset(filenames, compression_type='GZIP')

  # batch_size=1 is relied upon below: each batch is unpacked with [0].
  dataset = tf.data.experimental.make_batched_features_dataset(
    embedding_files_pattern,
    batch_size=1,
    num_epochs=1,
    features=feature_spec,
    reader=_gzip_reader_fn,
    shuffle=False
  )

  # Read embeddings from tfrecord files.
  logging.info('Loading embeddings from files...')
  zero_norm_count = 0
  for tfrecord_batch in dataset:
    vocabulary.append(tfrecord_batch["item_Id"].numpy()[0][0].decode())
    embedding = tfrecord_batch["embedding"].numpy()[0]
    norm = np.linalg.norm(embedding)
    if norm > 0:
      normalized_embedding = embedding / norm
    else:
      # A zero vector has no direction; dividing by its norm would produce
      # NaNs that propagate into everything built from the returned matrix.
      normalized_embedding = embedding
      zero_norm_count += 1
    embeddings.append(normalized_embedding)
  logging.info('Embeddings loaded.')
  if zero_norm_count:
    logging.warning(
        '%d embedding(s) had zero norm and were left unnormalized.',
        zero_norm_count)
  embeddings = np.array(embeddings)

  return vocabulary, embeddings