in paq/retrievers/build_index.py [0:0]
def build_index_streaming(cached_embeddings_path,
                          output_path,
                          hnsw=False,
                          sq8_quantization=False,
                          fp16_quantization=False,
                          store_n=256,
                          ef_search=32,
                          ef_construction=80,
                          sample_fraction=0.1,
                          indexing_batch_size=5000000,
                          ):
    """Build a faiss index over cached embeddings by streaming them in batches, then write it to disk.

    Embedding chunks are read from ``cached_embeddings_path`` one at a time and
    buffered; once more than ``indexing_batch_size`` rows are buffered they are
    concatenated and added to the index in one call. This keeps at most about one
    batch of raw vectors in memory alongside the index.

    Args:
        cached_embeddings_path: Location of the cached embeddings, as accepted by
            ``get_vectors_dim``, ``get_vector_sample`` and
            ``parse_vectors_from_directory``.
        output_path: Destination file passed to ``faiss.write_index``
            (``str`` or path-like).
        hnsw: Build an HNSW graph index. On this path every vector is passed
            through ``augment_vectors(..., max_phi)``; the HNSW indexes are sized
            ``vector_size + 1``, i.e. augmentation presumably appends one
            component (confirm against ``augment_vectors``).
        sq8_quantization: Use 8-bit scalar quantization. Takes precedence over
            ``fp16_quantization`` if both are set.
        fp16_quantization: Use fp16 scalar quantization.
        store_n: HNSW ``M`` (graph neighbours per node). Only used when ``hnsw``.
        ef_search: HNSW ``efSearch``. Only used when ``hnsw``.
        ef_construction: HNSW ``efConstruction``. Only used when ``hnsw``.
        sample_fraction: Passed to ``get_vector_sample``. The sample trains the
            scalar quantizer (when one is used) and supplies ``max_phi`` and the
            total vector count ``N`` used in progress logs.
        indexing_batch_size: Buffered-row threshold that triggers an ``index.add``.

    Returns:
        The populated faiss index (also written to ``output_path``).
    """
    vector_size = get_vectors_dim(cached_embeddings_path)

    if hnsw:
        if sq8_quantization:
            index = faiss.IndexHNSWSQ(vector_size + 1, faiss.ScalarQuantizer.QT_8bit, store_n)
        elif fp16_quantization:
            index = faiss.IndexHNSWSQ(vector_size + 1, faiss.ScalarQuantizer.QT_fp16, store_n)
        else:
            index = faiss.IndexHNSWFlat(vector_size + 1, store_n)
        index.hnsw.efSearch = ef_search
        index.hnsw.efConstruction = ef_construction
    else:
        # NOTE(review): the flat index below ranks by inner product, but these SQ
        # indexes use METRIC_L2 on the same un-augmented vectors -- confirm L2 is
        # really intended here.
        if sq8_quantization:
            index = faiss.IndexScalarQuantizer(vector_size, faiss.ScalarQuantizer.QT_8bit, faiss.METRIC_L2)
        elif fp16_quantization:
            index = faiss.IndexScalarQuantizer(vector_size, faiss.ScalarQuantizer.QT_fp16, faiss.METRIC_L2)
        else:
            # faiss has no `IndexIP`; the exact inner-product index is IndexFlatIP.
            # Vectors are not augmented on the non-HNSW path, so the dimension is
            # vector_size (not +1), and store_n (an HNSW parameter) does not apply.
            index = faiss.IndexFlatIP(vector_size)

    # The sample is always drawn because max_phi (augmentation) and N (logging)
    # come from it; its vectors are only needed for quantizer training.
    vector_sample, max_phi, N = get_vector_sample(cached_embeddings_path, sample_fraction)
    if sq8_quantization or fp16_quantization:  # scalar quantizers must be trained before adding
        if hnsw:
            # Train on vectors with the same (augmented) dimension as the index.
            vector_sample = augment_vectors(vector_sample, max_phi)
        vs = vector_sample.numpy()
        logging.info(f'Training Quantizer with matrix of shape {vs.shape}')
        index.train(vs)
        del vs
    del vector_sample  # release the sample before streaming the full collection

    def _add_chunks(chunks, start):
        """Concatenate buffered chunks, add them to the index and return how many rows were added."""
        to_add = torch.cat(chunks).numpy()  # faiss' Python API consumes numpy arrays
        logging.info(f'Adding Vectors {start} -> {start + to_add.shape[0]} of {N}')
        index.add(to_add)
        return to_add.shape[0]

    chunks_to_add = []
    n_buffered = 0  # running row count of chunks_to_add (avoids re-summing every iteration)
    added = 0       # number of vectors already added to the index
    for vector_chunk in parse_vectors_from_directory(cached_embeddings_path, as_chunks=True):
        if hnsw:
            vector_chunk = augment_vectors(vector_chunk, max_phi)
        chunks_to_add.append(vector_chunk)
        n_buffered += vector_chunk.shape[0]
        if n_buffered > indexing_batch_size:
            added += _add_chunks(chunks_to_add, added)
            chunks_to_add = []
            n_buffered = 0

    if chunks_to_add:  # flush the final partial batch
        added += _add_chunks(chunks_to_add, added)

    logger.info(f'Index Built, writing index to {output_path}')
    faiss.write_index(index, str(output_path))  # str() so pathlib.Path arguments work too
    logger.info('Index dumped')
    return index