easy_rec/python/tools/faiss_index_pai.py (95 lines of code) (raw):
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from __future__ import print_function
import logging
import os
import sys
import faiss
import numpy as np
import tensorflow as tf
from easy_rec.python.utils import io_util
logging.basicConfig(
level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
tf.app.flags.DEFINE_string('tables', '', 'tables passed by pai command')
tf.app.flags.DEFINE_integer('batch_size', 1024, 'batch size')
tf.app.flags.DEFINE_integer('embedding_dim', 32, 'embedding dimension')
tf.app.flags.DEFINE_string('index_output_dir', '', 'index output directory')
tf.app.flags.DEFINE_string('index_type', 'IVFFlat', 'index type')
tf.app.flags.DEFINE_integer('ivf_nlist', 1000, 'nlist')
tf.app.flags.DEFINE_integer('hnsw_M', 32, 'hnsw M')
tf.app.flags.DEFINE_integer('hnsw_efConstruction', 200, 'hnsw efConstruction')
tf.app.flags.DEFINE_integer('debug', 0, 'debug index')
FLAGS = tf.app.flags.FLAGS
def main(argv):
reader = tf.python_io.TableReader(
FLAGS.tables, slice_id=0, slice_count=1, capacity=FLAGS.batch_size * 2)
i = 0
id_map_f = tf.gfile.GFile(
os.path.join(FLAGS.index_output_dir, 'id_mapping'), 'w')
embeddings = []
while True:
try:
records = reader.read(FLAGS.batch_size)
for j, record in enumerate(records):
if isinstance(record[0], bytes):
eid = record[0].decode('utf-8')
id_map_f.write('%s\n' % eid)
embeddings.extend(
[list(map(float, record[1].split(b','))) for record in records])
i += 1
if i % 100 == 0:
logging.info('read %d embeddings.' % (i * FLAGS.batch_size))
except tf.python_io.OutOfRangeException:
break
reader.close()
id_map_f.close()
logging.info('Building faiss index..')
if FLAGS.index_type == 'IVFFlat':
quantizer = faiss.IndexFlatIP(FLAGS.embedding_dim)
index = faiss.IndexIVFFlat(quantizer, FLAGS.embedding_dim, FLAGS.ivf_nlist,
faiss.METRIC_INNER_PRODUCT)
elif FLAGS.index_type == 'HNSWFlat':
index = faiss.IndexHNSWFlat(FLAGS.embedding_dim, FLAGS.hnsw_M,
faiss.METRIC_INNER_PRODUCT)
index.hnsw.efConstruction = FLAGS.hnsw_efConstruction
else:
raise NotImplementedError
embeddings = np.array(embeddings)
if FLAGS.index_type == 'IVFFlat':
logging.info('train embeddings...')
index.train(embeddings)
logging.info('build embeddings...')
index.add(embeddings)
faiss.write_index(index, 'faiss_index')
with tf.gfile.GFile(
os.path.join(FLAGS.index_output_dir, 'faiss_index'), 'wb') as f_out:
with open('faiss_index', 'rb') as f_in:
f_out.write(f_in.read())
if FLAGS.debug != 0:
# IVFFlat
for ivf_nlist in [100, 500, 1000, 2000]:
quantizer = faiss.IndexFlatIP(FLAGS.embedding_dim)
index = faiss.IndexIVFFlat(quantizer, FLAGS.embedding_dim, ivf_nlist,
faiss.METRIC_INNER_PRODUCT)
index.train(embeddings)
index.add(embeddings)
index_name = 'faiss_index_ivfflat_nlist%d' % ivf_nlist
faiss.write_index(index, index_name)
with tf.gfile.GFile(
os.path.join(FLAGS.index_output_dir, index_name), 'wb') as f_out:
with open(index_name, 'rb') as f_in:
f_out.write(f_in.read())
# HNSWFlat
for hnsw_M in [16, 32, 64, 128]:
for hnsw_efConstruction in [64, 128, 256, 512, 1024, 2048, 4096, 8196]:
if hnsw_efConstruction < hnsw_M * 2:
continue
index = faiss.IndexHNSWFlat(FLAGS.embedding_dim, hnsw_M,
faiss.METRIC_INNER_PRODUCT)
index.hnsw.efConstruction = hnsw_efConstruction
index.add(embeddings)
index_name = 'faiss_index_hnsw_M%d_ef%d' % (hnsw_M, hnsw_efConstruction)
faiss.write_index(index, index_name)
with tf.gfile.GFile(
os.path.join(FLAGS.index_output_dir, index_name), 'wb') as f_out:
with open(index_name, 'rb') as f_in:
f_out.write(f_in.read())
if __name__ == '__main__':
sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
tf.app.run()