in retail/recommendation-system/bqml-scann/embeddings_lookup/lookup_creator.py
# Module-level context assumed by this excerpt: the imports below and the
# VOCABULARY_FILE_NAME constant live at the top of lookup_creator.py.
import numpy as np
import tensorflow as tf

VOCABULARY_FILE_NAME = 'vocabulary.txt'  # assumed value, for self-containment


def __init__(self, embedding_files_prefix, **kwargs):
  super(EmbeddingLookup, self).__init__(**kwargs)
  vocabulary = []
  embeddings = []

  # Read the item embeddings from the CSV files matching the glob pattern.
  # Each row has the format: <item_id>,<dim_0>,<dim_1>,...
  print('Loading embeddings from files...')
  for embedding_file in tf.io.gfile.glob(embedding_files_prefix):
    print(f'Loading embeddings in {embedding_file} ...')
    with tf.io.gfile.GFile(embedding_file, 'r') as lines:
      for line in lines:
        line_parts = line.split(',')
        if len(line_parts) < 2:
          continue  # skip empty or malformed rows
        try:
          item = line_parts[0]
          embedding = np.array([float(v) for v in line_parts[1:]])
        except ValueError:
          continue  # skip header rows and rows with non-numeric values
        vocabulary.append(item)
        embeddings.append(embedding)
  print('Embeddings loaded.')

  # Append an all-zeros out-of-vocabulary (OOV) embedding as the last row.
  embedding_size = len(embeddings[0])
  oov_embedding = np.zeros((1, embedding_size))
  self.embeddings = np.append(np.array(embeddings), oov_embedding, axis=0)
  print(f'Embeddings: {self.embeddings.shape}')

  # Write the vocabulary file and register it as a SavedModel asset so it
  # is exported alongside the model.
  print('Writing vocabulary to file...')
  with open(VOCABULARY_FILE_NAME, 'w') as f:
    for item in vocabulary:
      f.write(f'{item}\n')
  print('Vocabulary file written and will be added as a model asset.')
  self.vocabulary_file = tf.saved_model.Asset(VOCABULARY_FILE_NAME)

  # Map each item id to its row index in the embeddings matrix; ids not in
  # the vocabulary map to the OOV row at index len(vocabulary).
  initializer = tf.lookup.KeyValueTensorInitializer(
      keys=vocabulary, values=list(range(len(vocabulary))))
  self.token_to_id = tf.lookup.StaticHashTable(
      initializer, default_value=len(vocabulary))
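

# Usage sketch (illustrative, not from the repo): this assumes EmbeddingLookup
# subclasses a Keras Model/Layer, and that embedding CSV shards exist under
# the hypothetical GCS path below. It shows how the lookup table and the
# embeddings matrix built in __init__ fit together.
lookup = EmbeddingLookup(
    embedding_files_prefix='gs://your-bucket/embeddings/embeddings-*.csv')

# Map item ids to row indices; unseen ids fall back to the OOV row.
ids = lookup.token_to_id.lookup(tf.constant(['item_1', 'unseen_item']))
vectors = tf.gather(lookup.embeddings, ids)  # shape: (2, embedding_size)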