# retail/recommendation-system/bqml-scann/tfx_pipeline/lookup_creator.py

import logging

import numpy as np
import tensorflow as tf
import tensorflow_data_validation as tfdv
from tensorflow_transform.tf_metadata import schema_utils

# Assumed filename; the module defines this constant elsewhere.
VOCABULARY_FILE_NAME = 'vocabulary.txt'


class EmbeddingLookup(tf.keras.layers.Layer):
  """Keras layer that maps item ids to their embedding vectors."""

  def __init__(self, embedding_files_prefix, schema_file_path, **kwargs):
    super(EmbeddingLookup, self).__init__(**kwargs)

    vocabulary = list()
    embeddings = list()

    logging.info('Loading schema...')
    schema = tfdv.load_schema_text(schema_file_path)
    feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec
    logging.info('Schema is loaded.')

    def _gzip_reader_fn(filenames):
      # The embedding TFRecord files are GZIP-compressed.
      return tf.data.TFRecordDataset(filenames, compression_type='GZIP')

    dataset = tf.data.experimental.make_batched_features_dataset(
        embedding_files_prefix,
        batch_size=1,
        num_epochs=1,
        features=feature_spec,
        reader=_gzip_reader_fn,
        shuffle=False
    )
    # Read the item vocabulary and embeddings from the TFRecord files.
    logging.info('Loading embeddings from files...')
    for tfrecord_batch in dataset:
      vocabulary.append(tfrecord_batch['item_Id'].numpy()[0][0].decode())
      embeddings.append(tfrecord_batch['embedding'].numpy()[0])
    logging.info('Embeddings loaded.')

    embedding_size = len(embeddings[0])
    # Append a zero vector for out-of-vocabulary (OOV) items; it lands at
    # index len(vocabulary), matching the hash table's default_value below.
    oov_embedding = np.zeros((1, embedding_size))
    self.embeddings = np.append(np.array(embeddings), oov_embedding, axis=0)
    logging.info(f'Embeddings: {self.embeddings.shape}')
    # Write the vocabulary file.
    logging.info('Writing vocabulary to file...')
    with open(VOCABULARY_FILE_NAME, 'w') as f:
      for item in vocabulary:
        f.write(f'{item}\n')
    logging.info('Vocabulary file is written and will be added as a model asset.')

    # Track the vocabulary file as a SavedModel asset so it is exported
    # alongside the model.
    self.vocabulary_file = tf.saved_model.Asset(VOCABULARY_FILE_NAME)
    # Map each item id to its row index in the embedding matrix; ids missing
    # from the vocabulary get default_value == len(vocabulary), the OOV row.
    initializer = tf.lookup.KeyValueTensorInitializer(
        keys=vocabulary, values=list(range(len(vocabulary))))
    self.token_to_id = tf.lookup.StaticHashTable(
        initializer, default_value=len(vocabulary))
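
  # The excerpt above covers only __init__. A minimal call() sketch follows
  # to show how token_to_id and self.embeddings fit together; its body is an
  # illustrative assumption, not the file's actual implementation.
  def call(self, inputs):
    # Look up the row index for each item id; ids absent from the vocabulary
    # fall through to the default value, which points at the zero OOV row.
    ids = self.token_to_id.lookup(inputs)
    # Gather the corresponding embedding vectors.
    return tf.gather(self.embeddings, ids)

  # Hypothetical usage (paths are placeholders):
  #   lookup_layer = EmbeddingLookup('gs://bucket/embeddings/embeddings-*',
  #                                  'gs://bucket/schema/schema.pbtxt')
  #   vectors = lookup_layer(tf.constant(['item_123', 'unseen_item']))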