# tzrec/utils/faiss_util.py
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional, Tuple
import faiss
import numpy as np
import numpy.typing as npt
import pyarrow as pa
from tzrec.datasets.dataset import create_reader
from tzrec.utils.logging_util import logger
def build_faiss_index(
    embedding_input_path: str,
    id_field: str,
    embedding_field: str,
    index_type: str,
    batch_size: int = 1024,
    ivf_nlist: int = 1000,
    hnsw_M: int = 32,
    hnsw_efConstruction: int = 200,
    reader_type: Optional[str] = None,
    odps_data_quota_name: str = "pay-as-you-go",
) -> Tuple[faiss.Index, npt.NDArray]:
    """Build faiss index.

    The embedding table is read batch by batch, all embeddings are gathered
    into one matrix, and the matrix is added to a freshly created index. When
    faiss reports at least one GPU, the index is moved to the GPU selected by
    the LOCAL_RANK environment variable before training/adding.

    Args:
        embedding_input_path (str): path to embedding table.
        id_field (str): id field name in table.
        embedding_field (str): embedding field name in table. A column that
            is not an arrow list type is cast to string, split on "," and
            cast to a list of float32.
        index_type (str): index type, available is ["IVFFlatIP", "HNSWFlatIP",
            "IVFFlatL2", "HNSWFlatL2"].
        batch_size (int): table read batch_size.
        ivf_nlist (int): nlist of IVFFlat index.
        hnsw_M (int): M of HNSWFlat index.
        hnsw_efConstruction (int): efConstruction of HNSWFlat index.
        reader_type (str, optional): specify the input path reader type,
            if we cannot infer from input_path.
        odps_data_quota_name (str): maxcompute storage api/tunnel data quota name.

    Returns:
        index (faiss.Index): faiss index.
        index_id_map (NDArray): a list of embedding ids for mapping
            continuous ids to origin id.

    Raises:
        ValueError: if index_type does not end with "IP"/"L2", does not start
            with "IVFFlat"/"HNSWFlat", or if no embedding was read from the
            input.
    """
    # Validate index_type before touching the input, so that a typo fails
    # immediately instead of after the whole table has been scanned.
    if index_type.endswith("IP"):
        use_inner_product = True
    elif index_type.endswith("L2"):
        use_inner_product = False
    else:
        raise ValueError(f"Unknown metric_type in index {index_type}.")
    is_ivf = index_type.startswith("IVFFlat")
    if not is_ivf and not index_type.startswith("HNSWFlat"):
        raise ValueError(f"Unknown index_type: {index_type}")

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    # Only local rank zero emits progress logs.
    is_local_rank_zero = local_rank == 0

    reader = create_reader(
        input_path=embedding_input_path,
        batch_size=batch_size,
        selected_cols=[id_field, embedding_field],
        reader_type=reader_type,
        quota_name=odps_data_quota_name,
    )

    index_id_map = []
    embeddings = []
    for i, data in enumerate(reader.to_batches()):
        eid_data = data[id_field]
        emb_data = data[embedding_field]
        # np.stack cannot handle a batch without rows, so skip those.
        if len(emb_data) > 0:
            index_id_map.extend(eid_data.tolist())
            if not pa.types.is_list(emb_data.type):
                # Non-list columns are parsed as comma separated floats.
                emb_data = emb_data.cast(pa.string())
                emb_data = pa.compute.split_pattern(emb_data, ",")
                emb_data = emb_data.cast(pa.list_(pa.float32()), safe=False)
            embeddings.append(np.stack(emb_data.to_numpy(zero_copy_only=False)))
        if is_local_rank_zero and i % 100 == 0:
            logger.info(f"Reading {len(index_id_map)} embeddings...")

    if not embeddings:
        raise ValueError(
            f"No embeddings were read from {embedding_input_path}, "
            "cannot build faiss index."
        )
    embeddings = np.concatenate(embeddings)
    # Each stacked batch is a 2-D matrix, so axis 1 is the embedding size.
    embedding_dim = embeddings.shape[1]

    if is_local_rank_zero:
        logger.info("Building faiss index...")

    if use_inner_product:
        # pyre-ignore [16]
        metric_type = faiss.METRIC_INNER_PRODUCT
    else:
        # pyre-ignore [16]
        metric_type = faiss.METRIC_L2

    if is_ivf:
        # The flat coarse quantizer is only needed by the IVF index.
        if use_inner_product:
            # pyre-ignore [16]
            quantizer = faiss.IndexFlatIP(embedding_dim)
        else:
            # pyre-ignore [16]
            quantizer = faiss.IndexFlatL2(embedding_dim)
        # pyre-ignore [16]
        index = faiss.IndexIVFFlat(quantizer, embedding_dim, ivf_nlist, metric_type)
    else:
        # pyre-ignore [16]
        index = faiss.IndexHNSWFlat(embedding_dim, hnsw_M, metric_type)
        index.hnsw.efConstruction = hnsw_efConstruction

    # pyre-ignore [16]
    if faiss.get_num_gpus() > 0:
        # pyre-ignore [16]
        res = faiss.StandardGpuResources()
        # pyre-ignore [16]
        index = faiss.index_cpu_to_gpu(res, local_rank, index)

    if is_ivf:
        # Only the IVF index is trained before vectors are added.
        index.train(embeddings)
    index.add(embeddings)

    if is_local_rank_zero:
        logger.info("Build embeddings finished.")

    return index, np.array(index_id_map, dtype=str)
def write_faiss_index(
    index: faiss.Index, index_id_map: npt.NDArray, output_dir: str
) -> None:
    """Write faiss index.

    Two files are written into output_dir: "faiss_index", the serialized
    index, and "id_mapping", a text file with one origin id per line in the
    order of index_id_map. output_dir is created if it does not exist yet.

    Args:
        index (faiss.Index): faiss index.
        index_id_map (NDArray): a list of embedding ids
            for mapping continuous ids to origin id.
        output_dir (str): index output dir.
    """
    # An empty output_dir means the current directory; os.makedirs("") would
    # raise, so only create the directory when a path is actually given.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    if hasattr(index, "getResources"):  # gpu index
        # The index is converted to a cpu index before it is serialized.
        # pyre-ignore [16]
        index = faiss.index_gpu_to_cpu(index)
    # pyre-ignore [16]
    faiss.write_index(index, os.path.join(output_dir, "faiss_index"))

    with open(os.path.join(output_dir, "id_mapping"), "w") as f:
        for eid in index_id_map:
            f.write(f"{eid}\n")
def search_faiss_index(
    index: faiss.Index, index_id_map: npt.NDArray, query: npt.NDArray, k: int
) -> Tuple[npt.NDArray, npt.NDArray]:
    """Search faiss index.

    Args:
        index (faiss.Index): faiss index.
        index_id_map (NDArray): a list of embedding ids for mapping
            continuous ids to origin id.
        query (NDArray): search query.
        k (int): top k.

    Returns:
        distances (NDArray): a array of distances.
        ids (NDArray): a array of ids, same shape as distances. Slots for
            which faiss found no neighbour (label -1) hold the zero value of
            index_id_map's dtype, i.e. an empty string for string id maps.
    """
    distances, faiss_ids = index.search(query, k)

    # Faiss pads the result with label -1 when fewer than k neighbours are
    # found. Indexing the id map with -1 would silently return its LAST
    # entry, so padded slots must be excluded from the lookup.
    found = faiss_ids >= 0
    if found.all():
        return distances, index_id_map[faiss_ids]

    ids = np.zeros(faiss_ids.shape, dtype=index_id_map.dtype)
    ids[found] = index_id_map[faiss_ids[found]]
    return distances, ids