def __init__()

in retail/recommendation-system/bqml-scann/embeddings_lookup/lookup_creator.py [0:0]


  def __init__(self, embedding_files_prefix, **kwargs):
    super(EmbeddingLookup, self).__init__(**kwargs)

    vocabulary = list()
    embeddings = list()

    # Read embeddings from csv files.
    print('Loading embeddings from files...')
    for embedding_file in tf.io.gfile.glob(embedding_files_prefix):
      print(f'Loading embeddings in {embedding_file} ...')
      with tf.io.gfile.GFile(embedding_file, 'r') as lines:
        for line in lines:
          try:
            line_parts = line.split(',')
            item = line_parts[0]
            embedding = np.array([float(v) for v in line_parts[1:]])
            vocabulary.append(item)
            embeddings.append(embedding)
          except ValueError:
            # Skip the CSV header row and any malformed lines.
            pass
    print('Embeddings loaded.')
    
    # Append a zero vector as the out-of-vocabulary (OOV) embedding so that
    # unknown items resolve to an all-zeros row at the end of the matrix.
    embedding_size = len(embeddings[0])
    oov_embedding = np.zeros((1, embedding_size))
    self.embeddings = np.append(np.array(embeddings), oov_embedding, axis=0)
    print(f'Embeddings: {self.embeddings.shape}')

    # Write vocabulary file.
    print('Writing vocabulary to file...')
    with open(VOCABULARY_FILE_NAME, 'w') as f:
      for item in vocabulary: 
        f.write(f'{item}\n')
    print('Vocabulary file written and will be added as a model asset.')

    self.vocabulary_file = tf.saved_model.Asset(VOCABULARY_FILE_NAME)
    # Build a static table mapping each item id to its row index in
    # self.embeddings; unknown ids map to len(vocabulary), the OOV zero row.
    initializer = tf.lookup.KeyValueTensorInitializer(
        keys=vocabulary, values=list(range(len(vocabulary))))
    self.token_to_id = tf.lookup.StaticHashTable(
        initializer, default_value=len(vocabulary))
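
A minimal usage sketch, assuming the module-level imports and constants of lookup_creator.py (tensorflow as tf, numpy as np, VOCABULARY_FILE_NAME) and that EmbeddingLookup extends a Keras-style base class, as the super().__init__(**kwargs) call suggests. The GCS pattern and item ids below are hypothetical; the lookup uses only the token_to_id table and embeddings matrix built in __init__:

import tensorflow as tf

# Hypothetical glob pattern for the embedding CSV files; each row is
# "item_id,dim_0,dim_1,...,dim_N", as parsed by __init__ above.
embedding_files_prefix = 'gs://your-bucket/embeddings/embeddings-*.csv'

lookup_layer = EmbeddingLookup(embedding_files_prefix)

# Map item ids to row indices; unseen ids fall back to the OOV index.
ids = lookup_layer.token_to_id.lookup(tf.constant(['item_123', 'unseen_item']))
vectors = tf.gather(lookup_layer.embeddings, ids)
print(vectors.shape)  # (2, embedding_size)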