tzrec/tools/tdm/retrieval.py (387 lines of code) (raw):

# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import copy
import math  # noqa: F401  (kept for compatibility with existing imports)
import os
import time
from collections import OrderedDict
from multiprocessing import Process, Queue
from threading import Thread
from typing import Dict, Optional, Tuple

import numpy as np
import pyarrow as pa
import torch
from torch import distributed as dist
from torch.distributed import ReduceOp

from tzrec.constant import PREDICT_QUEUE_TIMEOUT, Mode
from tzrec.datasets.data_parser import DataParser
from tzrec.datasets.dataset import BaseWriter, create_writer
from tzrec.datasets.sampler import TDMPredictSampler
from tzrec.datasets.utils import Batch, RecordBatchTensor
from tzrec.main import _create_features, _get_dataloader, init_process_group
from tzrec.protos.data_pb2 import DatasetType
from tzrec.utils import config_util
from tzrec.utils.logging_util import ProgressLogger, logger


def update_data(
    input_data: pa.RecordBatch, sampled_data: Dict[str, pa.Array]
) -> Dict[str, pa.Array]:
    """Update input data based on sampled data.

    Every field in ``sampled_data`` is treated as an item feature and is taken
    from ``sampled_data`` unchanged. Every other column of ``input_data`` is
    treated as a user feature: each of its rows is repeated ``expand_num``
    times in place (row order 0,0,..,1,1,..), where ``expand_num`` is the
    sampled length divided by the input row count, so that user rows line up
    with their sampled candidates.

    Args:
        input_data (pa.RecordBatch): raw input data.
        sampled_data (dict): sampled data, keyed by item feature field name.

    Returns:
        updated data, keyed by field name.
    """
    item_fea_fields = sampled_data.keys()
    user_fea_fields = set(input_data.column_names) - item_fea_fields

    updated_data = dict(sampled_data)

    item_field_0 = list(item_fea_fields)[0]
    num_rows = len(input_data[item_field_0])
    expand_num = len(sampled_data[item_field_0]) // num_rows
    # The same repeat index applies to every user column, so build it once.
    index = np.repeat(np.arange(num_rows), expand_num)
    for user_fea in user_fea_fields:
        updated_data[user_fea] = input_data[user_fea].take(index)
    return updated_data


def _get_first_recall_layer(n_cluster: int, recall_num: int) -> int:
    """Return the smallest L with ``n_cluster ** L >= 2 * n_cluster * recall_num``.

    Equivalent to ``ceil(log(2 * n_cluster * recall_num, n_cluster))`` but
    computed with integers: the float logarithm can land just above an exact
    integer for exact powers (e.g. ``math.log(2**29, 2)`` evaluates to
    ``29.000000000000004``), which would push retrieval one layer deeper.

    Raises:
        ValueError: if ``n_cluster < 2`` or ``recall_num < 1``.
    """
    if n_cluster < 2:
        raise ValueError(f"n_cluster must be >= 2, got {n_cluster}")
    if recall_num < 1:
        raise ValueError(f"recall_num must be >= 1, got {recall_num}")
    target = 2 * n_cluster * recall_num
    layer, width = 0, 1
    while width < target:
        width *= n_cluster
        layer += 1
    return layer


def _count_recall_hits(gt_node_ids: pa.Array, recall_ids: pa.Array) -> int:
    """Count rows whose ground-truth id appears in that row's recalled ids.

    Rows of ``recall_ids`` can have different lengths (duplicates are dropped
    per row at the last layer), so membership is checked row by row instead
    of stacking the rows into one rectangular numpy array.
    """
    return sum(
        gt in ids
        for gt, ids in zip(gt_node_ids.to_pylist(), recall_ids.to_pylist())
    )


def _tdm_predict_data_worker(
    sampler: TDMPredictSampler,
    data_parser: DataParser,
    first_recall_layer: int,
    n_cluster: int,
    in_queue: Queue,
    out_queue: Queue,
    is_first_layer: bool,
    worker_id: int,
) -> None:
    """Sample candidate nodes for one retrieval layer (run in a worker process).

    Reads ``(record_batch_t, node_ids)`` from ``in_queue``, samples with
    ``sampler`` from ``node_ids``, expands user features with
    :func:`update_data`, parses the result into a model ``Batch`` and puts
    ``(batch, record_batch_t, candidate_item_ids)`` onto ``out_queue``.

    On the first recall layer the incoming ``node_ids`` is ignored: the worker
    samples from id ``-1`` with ``init_sampler(1)`` (presumably the tree
    root — TODO confirm), then samples ``first_recall_layer - 1`` further
    times with ``n_cluster`` before the final sampling step.

    A ``record_batch_t`` of ``None`` is a stop signal: the worker forwards
    ``(None, None, None)`` and exits.

    Args:
        sampler (TDMPredictSampler): tree sampler shared with the parent.
        data_parser (DataParser): parser turning raw columns into a Batch.
        first_recall_layer (int): tree layer retrieval starts from.
        n_cluster (int): tree cluster num.
        in_queue (Queue): input ``(record_batch_t, node_ids)`` queue.
        out_queue (Queue): output ``(batch, record_batch_t, item_ids)`` queue.
        is_first_layer (bool): whether this worker serves the first layer.
        worker_id (int): id passed to ``sampler.init``.
    """
    item_id_field = sampler._item_id_field
    sampler.init(worker_id)
    sampler.init_sampler(n_cluster)
    while True:
        record_batch_t, node_ids = in_queue.get(timeout=PREDICT_QUEUE_TIMEOUT)
        if record_batch_t is None:
            out_queue.put((None, None, None), timeout=PREDICT_QUEUE_TIMEOUT)
            # NOTE(review): presumably keeps the sampler client alive while
            # downstream stages drain — TODO confirm why 10s is needed.
            time.sleep(10)
            break
        record_batch = record_batch_t.get()
        if is_first_layer:
            sampler.init_sampler(1)
            gt_node_ids = record_batch[item_id_field]
            cur_batch_size = len(gt_node_ids)
            node_ids = sampler.get({item_id_field: pa.array([-1] * cur_batch_size)})[
                item_id_field
            ]
            # skip layers before first_recall_layer
            sampler.init_sampler(n_cluster)
            for _ in range(1, first_recall_layer):
                sampled_result_dict = sampler.get({item_id_field: node_ids})
                node_ids = sampled_result_dict[item_id_field]
        sampled_result_dict = sampler.get({item_id_field: node_ids})
        updated_inputs = update_data(record_batch, sampled_result_dict)
        output_data = data_parser.parse(updated_inputs)
        batch = data_parser.to_batch(output_data, force_no_tile=True)
        out_queue.put(
            (batch, record_batch_t, updated_inputs[item_id_field]),
            timeout=PREDICT_QUEUE_TIMEOUT,
        )


def tdm_retrieval(
    predict_input_path: str,
    predict_output_path: str,
    scripted_model_path: str,
    recall_num: int,
    n_cluster: int = 2,
    reserved_columns: Optional[str] = None,
    batch_size: Optional[int] = None,
    is_profiling: bool = False,
    debug_level: int = 0,
    dataset_type: Optional[str] = None,
    writer_type: Optional[str] = None,
    num_worker_per_level: int = 1,
) -> None:
    """Run layer-wise TDM retrieval with a scripted model and report recall.

    One pipeline stage runs per tree layer from the first recall layer to the
    last layer. The main thread feeds input batches to the first stage. Each
    stage has ``num_worker_per_level`` sampler processes that build candidate
    inputs, plus one forward thread that scores them. Intermediate layers keep
    the top ``2 * recall_num`` candidates per user. The last layer keeps up to
    ``recall_num`` distinct ids per user, ordered by score. A writer thread
    writes ``recall_ids`` and accumulates recall against the input
    ``item_id_field``.

    Args:
        predict_input_path (str): inference input data path.
        predict_output_path (str): inference output data path.
        scripted_model_path (str): directory holding ``pipeline.config`` and
            ``scripted_model.pt``.
        recall_num (int): recall item num per user.
        n_cluster (int): tree cluster num.
        reserved_columns (str, optional): comma separated columns to reserve
            in output.
        batch_size (int, optional): predict batch_size.
        is_profiling (bool): profiling predict process or not.
        debug_level (int, optional): debug level for debug parsed inputs etc.
        dataset_type (str, optional): dataset type, default use the type in pipeline.
        writer_type (str, optional): data writer type, default will be same as
            dataset_type in data_config.
        num_worker_per_level (int): num data generate worker per tree level.

    Raises:
        ValueError: if ``n_cluster < 2``, ``recall_num < 1``, or the first
            recall layer implied by them is not inside the configured tree.
    """
    reserved_cols: Optional[list[str]] = None
    if reserved_columns is not None:
        reserved_cols = [x.strip() for x in reserved_columns.split(",")]

    pipeline_config = config_util.load_pipeline_config(
        os.path.join(scripted_model_path, "pipeline.config")
    )
    if batch_size:
        pipeline_config.data_config.batch_size = batch_size
    if dataset_type:
        pipeline_config.data_config.dataset_type = getattr(DatasetType, dataset_type)

    # Validate the tree depth up front, before any process group, writer,
    # model or profiler is created.
    sampler_config = pipeline_config.data_config.tdm_sampler
    item_id_field: str = sampler_config.item_id_field
    max_level: int = len(sampler_config.layer_num_sample)
    first_recall_layer = _get_first_recall_layer(n_cluster, recall_num)
    num_recall_layers = max_level - first_recall_layer
    if num_recall_layers < 1:
        raise ValueError(
            f"recall_num={recall_num} with n_cluster={n_cluster} starts retrieval "
            f"at tree layer {first_recall_layer}, but tdm_sampler.layer_num_sample "
            f"only defines {max_level} layers."
        )

    device_and_backend = init_process_group()
    device: torch.device = device_and_backend[0]
    sparse_dtype: torch.dtype = torch.int32 if device.type == "cuda" else torch.int64

    is_rank_zero = int(os.environ.get("RANK", 0)) == 0
    is_local_rank_zero = int(os.environ.get("LOCAL_RANK", 0)) == 0

    data_config = pipeline_config.data_config
    data_config.ClearField("label_fields")
    data_config.drop_remainder = False
    # Build feature
    features = _create_features(list(pipeline_config.feature_configs), data_config)

    infer_data_config = copy.copy(data_config)
    infer_data_config.num_workers = 1
    infer_dataloader = _get_dataloader(
        infer_data_config,
        features,
        predict_input_path,
        reserved_columns=["ALL_COLUMNS"],
        mode=Mode.PREDICT,
        debug_level=debug_level,
    )
    infer_iterator = iter(infer_dataloader)

    if writer_type is None:
        writer_type = DatasetType.Name(data_config.dataset_type).replace(
            "Dataset", "Writer"
        )
    writer: BaseWriter = create_writer(
        predict_output_path,
        writer_type,
        quota_name=data_config.odps_data_quota_name,
    )

    # disable jit compile, as it compile too slow now.
    if "PYTORCH_TENSOREXPR_FALLBACK" not in os.environ:
        os.environ["PYTORCH_TENSOREXPR_FALLBACK"] = "2"
    model: torch.jit.ScriptModule = torch.jit.load(
        os.path.join(scripted_model_path, "scripted_model.pt"), map_location=device
    )
    model.eval()

    if is_local_rank_zero:
        plogger = ProgressLogger(desc="Predicting", miniters=10)

    if is_profiling:
        if is_rank_zero:
            logger.info(str(model))
        prof = torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=10, repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                os.path.join(scripted_model_path, "predict_trace")
            ),
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        )
        prof.start()

    data_parser = DataParser(features)

    dataset = infer_dataloader.dataset  # pyre-ignore [16]
    fields = dataset.input_fields  # pyre-ignore [29]
    # data_config.batch_size already reflects the batch_size override above,
    # and unlike the raw argument it is never None.
    predict_sampler = TDMPredictSampler(
        sampler_config, fields, data_config.batch_size, is_training=False
    )
    predict_sampler.init_cluster(
        num_client_per_rank=num_recall_layers * num_worker_per_level
    )
    predict_sampler.launch_server()

    num_class = pipeline_config.model_config.num_class
    pos_prob_name: str = "probs1" if num_class == 2 else "probs"

    def _forward(
        batch: Batch,
        record_batch_t: RecordBatchTensor,
        node_ids: pa.Array,
        layer_id: int,
    ) -> Tuple[RecordBatchTensor, pa.Array]:
        """Score one layer's candidates and keep the best ones per user.

        Returns ``record_batch_t`` unchanged plus the kept node ids. Below the
        last layer this is a flat array of the top ``2 * recall_num``
        candidates per user, in user-major order. At the last layer it is a
        list array with up to ``recall_num`` distinct ids per user, ordered by
        descending score.
        """
        with torch.no_grad():
            parsed_inputs = batch.to_dict(sparse_dtype=sparse_dtype)
            # when predicting with a model exported using INPUT_TILE,
            # we set the batch size tensor to 1 to disable tiling.
            parsed_inputs["batch_size"] = torch.tensor(1, dtype=torch.int64)
            predictions = model(parsed_inputs, device)

            gt_node_ids = record_batch_t.get()[item_id_field]
            cur_batch_size = len(gt_node_ids)
            # Each user's candidates are contiguous (update_data repeats user
            # rows in place), so one row per user.
            probs = predictions[pos_prob_name].reshape(cur_batch_size, -1)

            if layer_id == max_level - 1:
                candidate_ids = node_ids.to_numpy(zero_copy_only=False).reshape(
                    cur_batch_size, -1
                )
                sort_prob_index = torch.argsort(-probs, dim=1).cpu().numpy()
                sorted_cand_ids = np.take_along_axis(
                    candidate_ids, sort_prob_index, axis=1
                )
                recall_ids = []
                for row in sorted_cand_ids:
                    # An id may occur several times among the candidates; keep
                    # only its first (highest-scored) occurrence, in score order.
                    _, first_pos = np.unique(row, return_index=True)
                    recall_ids.append(row[np.sort(first_pos)[:recall_num]])
                return record_batch_t, pa.array(recall_ids)

            _, topk_indices_in_group = torch.topk(probs, 2 * recall_num, dim=1)
            # Convert per-row indices into indices of the flat node_ids array.
            row_offsets = torch.arange(cur_batch_size, device=device).unsqueeze(1)
            topk_indices = topk_indices_in_group + row_offsets * probs.size(1)
            topk_indices = topk_indices.reshape(-1).cpu().numpy()
            return record_batch_t, node_ids.take(topk_indices)

    def _forward_loop(
        data_queue: Queue, pred_queue: Queue, layer_id: int, num_stop_signals: int
    ) -> None:
        """Run ``_forward`` for one layer until all upstream workers stopped.

        Each of the ``num_worker_per_level`` upstream worker processes sends
        one stop signal. Once all of them have arrived, ``num_stop_signals``
        stop signals are sent downstream, one per consumer of ``pred_queue``.
        """
        stop_cnt = 0
        while True:
            batch, record_batch_t, node_ids = data_queue.get(
                timeout=PREDICT_QUEUE_TIMEOUT
            )
            if batch is None:
                stop_cnt += 1
                if stop_cnt < num_worker_per_level:
                    continue
                for _ in range(num_stop_signals):
                    pred_queue.put((None, None), timeout=PREDICT_QUEUE_TIMEOUT)
                break
            pred = _forward(batch, record_batch_t, node_ids, layer_id)
            pred_queue.put(pred, timeout=PREDICT_QUEUE_TIMEOUT)

    def _write_loop(pred_queue: Queue, metric_queue: Queue) -> None:
        """Write recalled ids and accumulate recall until a stop signal.

        Puts the final ``(total, recall)`` counts onto ``metric_queue``.
        """
        total = 0
        recall = 0
        while True:
            record_batch_t, node_ids = pred_queue.get(timeout=PREDICT_QUEUE_TIMEOUT)
            if record_batch_t is None:
                break
            reserve_batch_record = record_batch_t.get()
            gt_node_ids = reserve_batch_record[item_id_field]

            output_dict = OrderedDict()
            if reserved_cols is not None:
                for c in reserved_cols:
                    output_dict[c] = reserve_batch_record[c]
            output_dict["recall_ids"] = node_ids
            writer.write(output_dict)

            # calculate recall
            total += len(gt_node_ids)
            recall += _count_recall_hits(gt_node_ids, node_ids)
        metric_queue.put((total, recall), timeout=PREDICT_QUEUE_TIMEOUT)

    # in_queues[i] feeds the sampler workers of recall layer i and the extra
    # last queue feeds the writer thread; out_queues[i] feeds the forward
    # thread of recall layer i.
    in_queues = [Queue(maxsize=2) for _ in range(num_recall_layers + 1)]
    out_queues = [Queue(maxsize=2) for _ in range(num_recall_layers)]
    metric_queue = Queue(maxsize=1)

    data_p_list = []
    for i in range(num_recall_layers):
        for j in range(num_worker_per_level):
            p = Process(
                target=_tdm_predict_data_worker,
                args=(
                    predict_sampler,
                    data_parser,
                    first_recall_layer,
                    n_cluster,
                    in_queues[i],
                    out_queues[i],
                    i == 0,
                    i * num_worker_per_level + j,
                ),
            )
            p.start()
            data_p_list.append(p)

    forward_t_list = []
    for i in range(num_recall_layers):
        is_last_layer = i == num_recall_layers - 1
        t = Thread(
            target=_forward_loop,
            args=(
                out_queues[i],
                in_queues[i + 1],
                i + first_recall_layer,
                # The writer thread consumes exactly one stop signal; sending
                # more would fill its bounded queue and block until timeout.
                1 if is_last_layer else num_worker_per_level,
            ),
        )
        t.start()
        forward_t_list.append(t)

    write_t = Thread(target=_write_loop, args=(in_queues[-1], metric_queue))
    write_t.start()

    for i_step, batch in enumerate(infer_iterator):
        in_queues[0].put((batch.reserves, None), timeout=PREDICT_QUEUE_TIMEOUT)
        if is_local_rank_zero:
            plogger.log(i_step)
        if is_profiling:
            prof.step()
    for _ in range(num_worker_per_level):
        in_queues[0].put((None, None), timeout=PREDICT_QUEUE_TIMEOUT)

    for p in data_p_list:
        p.join()
    for t in forward_t_list:
        t.join()
    write_t.join()
    writer.close()

    total, recall = metric_queue.get(timeout=PREDICT_QUEUE_TIMEOUT)
    total_t = torch.tensor(total, device=device)
    recall_t = torch.tensor(recall, device=device)
    dist.all_reduce(total_t, op=ReduceOp.SUM)
    dist.all_reduce(recall_t, op=ReduceOp.SUM)
    total_all = total_t.cpu().item()
    # Guard against empty input across all ranks.
    # pyre-ignore [6]
    recall_ratio = recall_t.cpu().item() / total_all if total_all > 0 else 0.0

    if is_profiling:
        prof.stop()
    if is_rank_zero:
        logger.info(f"Retrieval Finished. Recall:{recall_ratio}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--scripted_model_path",
        type=str,
        default=None,
        help="scripted model directory, containing pipeline.config and "
        "scripted_model.pt.",
    )
    parser.add_argument(
        "--predict_input_path",
        type=str,
        default=None,
        help="inference data input path",
    )
    parser.add_argument(
        "--predict_output_path",
        type=str,
        default=None,
        help="inference data output path",
    )
    parser.add_argument(
        "--reserved_columns",
        type=str,
        default=None,
        help="column names to reserved in output",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=None,
        help="predict batch size, default will use batch size in config.",
    )
    parser.add_argument(
        "--is_profiling",
        action="store_true",
        default=False,
        help="profiling predict progress.",
    )
    parser.add_argument(
        "--debug_level",
        type=int,
        default=0,
        help="debug level for debug parsed inputs etc.",
    )
    parser.add_argument(
        "--dataset_type",
        type=str,
        default=None,
        help="dataset type, default will use dataset type in config.",
    )
    parser.add_argument(
        "--writer_type",
        type=str,
        default=None,
        help="data writer type, default will be same as dataset_type in "
        "data_config.",
    )
    parser.add_argument(
        "--recall_num", type=int, default=200, help="recall item num per user."
    )
    parser.add_argument("--n_cluster", type=int, default=2, help="tree cluster num.")
    parser.add_argument(
        "--num_worker_per_level",
        type=int,
        default=1,
        help="num data generate worker per tree level.",
    )
    args, extra_args = parser.parse_known_args()

    tdm_retrieval(
        predict_input_path=args.predict_input_path,
        predict_output_path=args.predict_output_path,
        scripted_model_path=args.scripted_model_path,
        recall_num=args.recall_num,
        n_cluster=args.n_cluster,
        reserved_columns=args.reserved_columns,
        batch_size=args.batch_size,
        is_profiling=args.is_profiling,
        debug_level=args.debug_level,
        dataset_type=args.dataset_type,
        writer_type=args.writer_type,
        num_worker_per_level=args.num_worker_per_level,
    )