tzrec/tools/hitrate.py
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from collections import OrderedDict
from typing import Any, List, Optional, Tuple

import numpy as np
import numpy.typing as npt
import pyarrow as pa
import torch
from numpy.linalg import norm
from torch import distributed as dist

from tzrec.datasets.dataset import create_reader, create_writer
from tzrec.main import init_process_group
from tzrec.utils import faiss_util
from tzrec.utils.logging_util import logger


def batch_hitrate(
src_ids: List[Any], # pyre-ignore [2]
recall_ids: npt.NDArray,
gt_items: List[List[str]],
max_num_interests: int,
num_interests: Optional[List[int]] = None,
) -> Tuple[List[float], List[List[str]], float, float]:
"""Compute hitrate of a batch of src ids.
Args:
src_ids (list): trigger id, a list.
recall_ids (NDArray): recalled ids by src_ids, a numpy array.
gt_items (list): batch of ground truth item ids list, a list of list.
max_num_interests (int): max number of interests.
num_interests (list): some models have different number of interests.

    Returns:
        hitrates (list): hitrate of each src id, a list.
        hit_ids (list): hit item ids of each src id, a list of lists.
        hits (float): total hit count of the batch, a scalar.
        gt_count (float): total number of ground truth items in the batch, a scalar.
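
    Example (illustrative, single interest):
        batch_hitrate(["u1"], np.array([[["a", "b", "c"]]]), [["a", "d"]], 1)
        returns hitrates=[0.5], hit_ids=[["a"]], hits=1.0 and gt_count=2.0,
        because one of the two ground truth items ("a") is recalled.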
"""
hit_ids = []
hitrates = []
hits = 0.0
gt_count = 0.0
for idx, src_id in enumerate(src_ids):
recall_id = recall_ids[idx]
gt_item = set(gt_items[idx])
gt_items_size = len(gt_item)
hit_id_set = set()
if gt_items_size == 0: # just skip invalid record.
            logger.warning(
                "Id {} has no related items sequence, just skip.".format(src_id)
            )
continue
for interest_id in range(max_num_interests):
if num_interests and interest_id >= num_interests[idx]:
break
hit_id_set |= set(recall_id[interest_id]) & gt_item
hit_count = float(len(hit_id_set))
hitrates.append(hit_count / gt_items_size)
hits += hit_count
gt_count += gt_items_size
hit_ids.append(list(hit_id_set))
return hitrates, hit_ids, hits, gt_count


def interest_merge(
user_emb: npt.NDArray,
recall_distances: npt.NDArray,
recall_ids: npt.NDArray,
top_k: int,
num_interests: int,
index_type: str,
) -> Tuple[npt.NDArray, npt.NDArray]:
"""Merge the recall results of different interests.
Args:
user_emb (NDArray): user embedding.
recall_distances (NDArray): recall distances.
recall_ids (NDArray): recall ids.
top_k (int): top k candidates.
num_interests (int): number of interests.
index_type (str): index type.

    Returns:
        recall_distances (NDArray): merged recall distances.
        recall_ids (NDArray): merged recall ids.
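
    Example (illustrative, IP index with top_k=2 and num_interests=2):
        A request whose per-interest results are ids = [["a", "b"], ["a", "c"]]
        with distances = [[0.9, 0.5], [0.9, 0.3]] is flattened and sorted
        globally to ["a", "a", "b", "c"] / [0.9, 0.9, 0.5, 0.3]; equal
        consecutive distances are treated as duplicates and pushed to the end,
        so the merged result is ids = [["a", "b"]] with distances [[0.9, 0.5]],
        reshaped to [-1, 1, top_k].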
"""
# In case of all-zero query vector, the corresponding knn results
# should be removed since faiss returns random target for all-zero query.
if index_type.endswith("IP"):
recall_distances = np.minimum(
recall_distances,
np.tile(
(
(norm(user_emb, axis=-1, keepdims=True) != 0.0).astype("float") * 2
- 1
)
* 1e32,
(1, top_k),
),
)
else: # L2 distance
recall_distances = np.maximum(
recall_distances,
np.tile(
(
(norm(user_emb, axis=-1, keepdims=True) == 0.0).astype("float") * 2
- 1
)
* 1e32,
(1, top_k),
),
)
recall_distances_flat = recall_distances.reshape(
[-1, num_interests * recall_distances.shape[-1]]
)
    recall_ids_flat = recall_ids.reshape(
        [-1, num_interests * recall_ids.shape[-1]]
    )
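    # jointly sort the num_interests * top_k candidates of each request so that
    # the best candidates across all interests come first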
sort_idx = np.argsort(recall_distances_flat, axis=-1)
if index_type.endswith("IP"): # inner product should be sorted in descending order
sort_idx = sort_idx[:, ::-1]
recall_distances_flat_sorted = recall_distances_flat[
np.arange(recall_distances_flat.shape[0])[:, np.newaxis], sort_idx
]
recall_ids_flat_sorted = recall_ids_flat[
np.arange(recall_ids_flat.shape[0])[:, np.newaxis], sort_idx
]
# get unique candidates
recall_distances_flat_sorted_pad = np.concatenate(
[
recall_distances_flat_sorted,
np.zeros((recall_distances_flat_sorted.shape[0], 1)),
],
axis=-1,
)
# compute diff value between consecutive distances
recall_distances_diff = (
recall_distances_flat_sorted_pad[:, 0:-1]
- recall_distances_flat_sorted_pad[:, 1:]
)
if index_type.endswith("IP"):
pad_value = -1e32
else:
pad_value = 1e32
    # zero-diff positions are duplicated values, so we pad them with a pad value
recall_distances_unique = np.where(
recall_distances_diff == 0, pad_value, recall_distances_flat_sorted
)
    # sort again to get the unique candidates; duplicated values were padded to
    # -1e32 (IP) or 1e32 (L2), so they are moved to the end
sort_idx_new = np.argsort(recall_distances_unique, axis=-1)
if index_type.endswith("IP"):
sort_idx_new = sort_idx_new[:, ::-1]
recall_distances = recall_distances_flat_sorted[
np.arange(recall_distances_flat_sorted.shape[0])[:, np.newaxis],
sort_idx_new[:, 0:top_k],
]
recall_ids = recall_ids_flat_sorted[
np.arange(recall_ids_flat_sorted.shape[0])[:, np.newaxis],
sort_idx_new[:, 0:top_k],
]
recall_distances = recall_distances.reshape([-1, 1, recall_distances.shape[-1]])
    recall_ids = recall_ids.reshape([-1, 1, recall_ids.shape[-1]])
return recall_distances, recall_ids


if __name__ == "__main__":
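    # Example launch (illustrative paths; the tool is typically started with
    # torchrun so that RANK/WORLD_SIZE are set for init_process_group):
    #   torchrun --nnodes=1 --nproc-per-node=1 -m tzrec.tools.hitrate \
    #       --user_gt_input data/user_gt.parquet \
    #       --item_embedding_input data/item_emb.parquet \
    #       --total_hitrate_output data/total_hitrate \
    #       --hitrate_details_output data/hitrate_details \
    #       --reader_type ParquetReader --writer_type ParquetWriter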
parser = argparse.ArgumentParser()
parser.add_argument(
"--user_gt_input",
type=str,
default=None,
help="Path to user groudtruth & embedding table with columns [request_id, "
"gt_items, user_tower_emb]",
)
parser.add_argument(
"--item_embedding_input",
type=str,
default=None,
help="Path to item embedding table with columns [item_id, item_tower_emb]",
)
parser.add_argument(
"--total_hitrate_output",
type=str,
default=None,
help="Path to hitrate table with columns [hitrate]",
)
parser.add_argument(
"--hitrate_details_output",
type=str,
default=None,
help="Path to hitrate detail table with columns [id, topk_ids, "
"topk_dists, hitrate, hit_ids]",
)
parser.add_argument(
"--batch_size",
type=int,
default=1024,
help="batch size.",
)
parser.add_argument(
"--index_type",
type=str,
default="IVFFlatIP",
choices=["IVFFlatIP", "IVFFlatL2"],
help="index type.",
)
parser.add_argument(
"--top_k", type=int, default=200, help="use top k search result."
)
parser.add_argument(
"--topk_across_interests",
action="store_true",
default=False,
help="select topk candidates across all interests.",
)
parser.add_argument(
"--ivf_nlist", type=int, default=1000, help="nlist of IVFFlat index."
)
parser.add_argument(
"--ivf_nprobe", type=int, default=800, help="nprobe of IVFFlat index."
)
parser.add_argument(
"--item_id_field",
type=str,
default="item_id",
help="item id field name in item embedding table.",
)
parser.add_argument(
"--item_embedding_field",
type=str,
default="item_tower_emb",
help="item embedding field name in item embedding table.",
)
parser.add_argument(
"--request_id_field",
type=str,
default="request_id",
help="request id field name in user gt table.",
)
parser.add_argument(
"--gt_items_field",
type=str,
default="gt_items",
help="gt items field name in user gt table.",
)
parser.add_argument(
"--user_embedding_field",
type=str,
default="user_tower_emb",
help="user embedding field name in user gt table.",
)
parser.add_argument(
"--num_interests",
type=int,
default=1,
help="max user embedding num for each request.",
)
parser.add_argument(
"--num_interests_field",
type=str,
default=None,
help="valid user embedding num for each request in user gt table.",
)
parser.add_argument(
"--reader_type",
type=str,
default=None,
choices=["OdpsReader", "CsvReader", "ParquetReader"],
help="input path reader type.",
)
parser.add_argument(
"--writer_type",
type=str,
default=None,
choices=["OdpsWriter", "CsvWriter", "ParquetWriter"],
help="output path writer type.",
)
parser.add_argument(
"--odps_data_quota_name",
type=str,
default="pay-as-you-go",
help="maxcompute storage api/tunnel data quota name.",
)
args, extra_args = parser.parse_known_args()
device, backend = init_process_group()
worker_id = int(os.environ.get("RANK", 0))
num_workers = int(os.environ.get("WORLD_SIZE", 1))
    is_rank_zero = worker_id == 0
    is_local_rank_zero = int(os.environ.get("LOCAL_RANK", 0)) == 0
selected_cols = [
args.request_id_field,
args.gt_items_field,
args.user_embedding_field,
]
if args.num_interests_field is not None:
selected_cols.append(args.num_interests_field)
reader = create_reader(
input_path=args.user_gt_input,
batch_size=args.batch_size,
selected_cols=selected_cols,
reader_type=args.reader_type,
quota_name=args.odps_data_quota_name,
)
writer_type = args.writer_type
if not writer_type:
# pyre-ignore [16]
writer_type = reader.__class__.__name__.replace("Reader", "Writer")
index, index_id_map = faiss_util.build_faiss_index(
args.item_embedding_input,
id_field=args.item_id_field,
embedding_field=args.item_embedding_field,
index_type=args.index_type,
batch_size=args.batch_size,
ivf_nlist=args.ivf_nlist,
reader_type=args.reader_type,
odps_data_quota_name=args.odps_data_quota_name,
)
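    # nprobe controls how many IVF clusters are scanned per query; larger values
    # trade search speed for better recall of the exact neighbors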
index.nprobe = args.ivf_nprobe
details_writer = None
if args.hitrate_details_output:
details_writer = create_writer(
args.hitrate_details_output,
writer_type,
quota_name=args.odps_data_quota_name,
)
if args.topk_across_interests:
print("args.topk_across_interests is True")
# calculate hitrate
total_count = 0
total_hits = 0.0
total_gt_count = 0.0
for i, data in enumerate(reader.to_batches(worker_id, num_workers)):
request_id = data[args.request_id_field]
gt_items = data[args.gt_items_field]
if not pa.types.is_list(gt_items.type):
gt_items = gt_items.cast(pa.string())
gt_items = pa.compute.split_pattern(gt_items, ",")
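        # the user embedding column may arrive as a nested float list (one inner
        # list per interest), a flat float list, or a delimited string in which
        # interests are separated by ';' and dimensions by ','; the block below
        # normalizes all of them into a dense float32 matrix of shape
        # (batch_size * num_interests, emb_dim)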
user_emb = data[args.user_embedding_field]
user_emb_type = user_emb.type
if pa.types.is_list(user_emb_type):
if pa.types.is_list(user_emb_type.value_type):
user_emb = user_emb.values
else:
user_emb = user_emb.cast(pa.string())
if args.num_interests > 1:
user_emb = pa.compute.split_pattern(user_emb, ";").values
user_emb = pa.compute.split_pattern(user_emb, ",")
user_emb = user_emb.cast(pa.list_(pa.float32()), safe=False)
user_emb = np.stack(user_emb.to_numpy(zero_copy_only=False))
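        # faiss search returns one row per interest embedding, so recall_distances
        # and recall_ids have shape (batch_size * num_interests, top_k)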
recall_distances, recall_ids = faiss_util.search_faiss_index(
index, index_id_map, user_emb, args.top_k
)
# pick topk candidates across all interests
if args.topk_across_interests:
recall_distances, recall_ids = interest_merge(
user_emb,
recall_distances,
recall_ids,
args.top_k,
args.num_interests,
args.index_type,
)
else: # pick topk candidates for each interest
recall_distances = recall_distances.reshape(
[-1, args.num_interests, recall_distances.shape[-1]]
)
recall_ids = recall_ids.reshape(
[-1, args.num_interests, recall_distances.shape[-1]]
)
num_interests_per_req = None
if args.num_interests_field:
num_interests_per_req = data[args.num_interests_field]
hitrates, hit_ids, hits, gt_count = batch_hitrate(
request_id.tolist(),
recall_ids,
gt_items.tolist(),
args.num_interests if not args.topk_across_interests else 1,
            num_interests_per_req.tolist()
            if num_interests_per_req is not None
            else None,
)
total_hits += hits
total_gt_count += gt_count
total_count += len(request_id)
if is_local_rank_zero and i % 10 == 0:
logger.info(f"Compute {total_count} hitrates...")
if details_writer:
details_writer.write(
OrderedDict(
[
("id", request_id),
(
"topk_ids",
pa.array(
recall_ids.tolist(),
                                type=pa.list_(pa.list_(pa.string())),
),
),
(
"topk_dists",
pa.array(
recall_distances.tolist(),
type=pa.list_(pa.list_(pa.float32())),
),
),
("hitrate", pa.array(hitrates)),
("hit_ids", pa.array(hit_ids, type=pa.list_(pa.string()))),
]
)
)
if details_writer:
details_writer.close()
    # reduce hits and ground truth counts across workers to get the global hitrate
total_hits_t = torch.tensor(total_hits, device=device)
total_gt_count_t = torch.tensor(total_gt_count, device=device)
dist.all_reduce(total_hits_t)
dist.all_reduce(total_gt_count_t)
# output hitrate
total_hitrate = (total_hits_t / total_gt_count_t).cpu().item()
if is_rank_zero:
logger.info(f"Total hitrate: {total_hitrate}")
        if args.total_hitrate_output:
hitrate_writer = create_writer(
args.total_hitrate_output,
writer_type,
quota_name=args.odps_data_quota_name,
)
hitrate_writer.write({"hitrate": pa.array([total_hitrate])})
hitrate_writer.close()