def __init__()

in retail/recommendation-system/bqml-scann/tfx_pipeline/lookup_creator.py [0:0]


  def __init__(self, embedding_files_prefix, schema_file_path, **kwargs):
    super(EmbeddingLookup, self).__init__(**kwargs)
    
    # Accumulate the item vocabulary and the matching embedding vectors.
    vocabulary = []
    embeddings = []

    logging.info('Loading schema...')
    schema = tfdv.load_schema_text(schema_file_path)
    feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec
    logging.info('Schema is loaded.')
    
    def _gzip_reader_fn(filenames):
      return tf.data.TFRecordDataset(filenames, compression_type='GZIP')
    
    # Stream one example per batch from the GZIP-compressed TFRecord files
    # matching embedding_files_prefix.
    dataset = tf.data.experimental.make_batched_features_dataset(
      embedding_files_prefix,
      batch_size=1,
      num_epochs=1,
      features=feature_spec,
      reader=_gzip_reader_fn,
      shuffle=False
    )

    # Read embeddings from tfrecord files.
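    # Each record is expected to carry an `item_Id` string feature and an
    # `embedding` float vector; these keys must match the loaded schema.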
    logging.info('Loading embeddings from files ...')
    for tfrecord_batch in dataset:
      vocabulary.append(tfrecord_batch["item_Id"].numpy()[0][0].decode())
      embeddings.append(tfrecord_batch["embedding"].numpy()[0])
    logging.info('Embeddings loaded.')
    
    # Append a zero vector as the out-of-vocabulary (OOV) embedding; unknown
    # items are mapped to this last row by the lookup table created below.
    embedding_size = len(embeddings[0])
    oov_embedding = np.zeros((1, embedding_size))
    self.embeddings = np.append(np.array(embeddings), oov_embedding, axis=0)
    logging.info(f'Embeddings: {self.embeddings.shape}')

    # Write vocabulary file.
    logging.info('Writing vocabulary to file ...')
    with open(VOCABULARY_FILE_NAME, 'w') as f:
      for item in vocabulary: 
        f.write(f'{item}\n')
    logging.info('Vocabulary file written and will be added as a model asset.')

    self.vocabulary_file = tf.saved_model.Asset(VOCABULARY_FILE_NAME)
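    # Map each item id to its row index in the embedding matrix; ids that are
    # not in the vocabulary fall through to len(vocabulary), the OOV row.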
    initializer = tf.lookup.KeyValueTensorInitializer(
        keys=vocabulary, values=list(range(len(vocabulary))))
    self.token_to_id = tf.lookup.StaticHashTable(
        initializer, default_value=len(vocabulary))
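
For context, a minimal usage sketch of the attributes this constructor sets up. The file paths below are hypothetical placeholders, and the layer's call method is not part of this excerpt; only `token_to_id` and `embeddings` come from the code above:

  # Hypothetical paths; substitute the real embedding files prefix and schema.
  lookup = EmbeddingLookup(
      embedding_files_prefix='gs://my-bucket/embeddings/embeddings-*',
      schema_file_path='gs://my-bucket/schema/schema.pbtxt')

  item_ids = tf.constant(['item_123', 'never_seen_item'])
  indices = lookup.token_to_id.lookup(item_ids)    # unknown ids -> len(vocabulary)
  vectors = tf.gather(lookup.embeddings, indices)  # last row is the zero OOV vector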