easy_rec/python/tools/hit_rate_ds.py (161 lines of code) (raw):

# Copyright 2020 Alibaba Group Holding Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= # """Evaluation of Top k hitrate.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import json import logging import os import sys import graphlearn as gl import tensorflow as tf from easy_rec.python.protos.dataset_pb2 import DatasetConfig from easy_rec.python.utils import config_util from easy_rec.python.utils import io_util from easy_rec.python.utils.config_util import process_multi_file_input_path from easy_rec.python.utils.hit_rate_utils import compute_hitrate_batch from easy_rec.python.utils.hit_rate_utils import load_graph from easy_rec.python.utils.hit_rate_utils import reduce_hitrate from easy_rec.python.utils.hive_utils import HiveUtils if tf.__version__ >= '2.0': tf = tf.compat.v1 from easy_rec.python.utils.distribution_utils import set_tf_config_and_get_train_worker_num_on_ds # NOQA logging.basicConfig( format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s', level=logging.INFO) tf.app.flags.DEFINE_string('item_emb_table', '', 'item embedding table name') tf.app.flags.DEFINE_string('gt_table', '', 'ground truth table name') tf.app.flags.DEFINE_string('hitrate_details_result', '', 'hitrate detail file path') tf.app.flags.DEFINE_string('total_hitrate_result', '', 'total hitrate result file path') 
tf.app.flags.DEFINE_string('pipeline_config_path', '', 'pipeline config path')
tf.app.flags.DEFINE_integer('batch_size', 512, 'batch size')
tf.app.flags.DEFINE_integer('emb_dim', 128, 'embedding dimension')
tf.app.flags.DEFINE_string('recall_type', 'i2i', 'i2i or u2i')
# Fixed: integer flags previously declared with string defaults ('5', '0',
# '60'); absl tolerates that but it defeats the point of a typed flag.
tf.app.flags.DEFINE_integer('top_k', 5, 'top_k hitrate.')
tf.app.flags.DEFINE_integer('knn_metric', 0, '0(l2) or 1(ip).')
tf.app.flags.DEFINE_bool('knn_strict', False, 'use exact search.')
tf.app.flags.DEFINE_integer('timeout', 60, 'timeout')
tf.app.flags.DEFINE_integer('num_interests', 1, 'max number of interests')
tf.app.flags.DEFINE_string('gt_table_field_sep', '\t', 'gt_table_field_sep')
tf.app.flags.DEFINE_string('item_emb_table_field_sep', '\t',
                           'item_emb_table_field_sep')
tf.app.flags.DEFINE_bool('is_on_ds', False, help='is on ds')

FLAGS = tf.app.flags.FLAGS


def compute_hitrate(g, gt_all, hitrate_writer, gt_table=None):
  """Compute hitrate of each worker.

  Args:
    g: a GL Graph instance.
    gt_all: iterable yielding batches of ground-truth records (one batch is a
      list of tuples, see gt_hdfs / HiveUtils.hive_read_lines).
    hitrate_writer: writer for the per-record hitrate detail lines.
    gt_table: ground truth table name. Currently unused; kept for
      backward-compatible call sites.

  Returns:
    total_hits: total hits of this worker.
    total_gt_count: total count of ground truth items of this worker.
  """
  total_hits = 0.0
  total_gt_count = 0.0
  for gt_record in gt_all:
    gt_record = list(gt_record)
    hits, gt_count, src_ids, recall_ids, recall_distances, hitrates, bad_cases, bad_dists = \
        compute_hitrate_batch(g, gt_record, FLAGS.emb_dim, FLAGS.num_interests,
                              FLAGS.top_k)
    total_hits += hits
    total_gt_count += gt_count
    src_ids = [str(ids) for ids in src_ids]
    hitrates = [str(hitrate) for hitrate in hitrates]
    topk_recalls = [','.join(str(x) for x in ids) for ids in recall_ids]
    # Per source id: interests are comma-separated, the top-k distances of
    # one interest are joined with '|'.
    topk_dists = [
        ','.join('|'.join(str(x) for x in dist) for dist in dists)
        for dists in recall_distances
    ]
    bad_cases = [','.join(str(x) for x in bad_case) for bad_case in bad_cases]
    bad_dists = [','.join(str(x) for x in dist) for dist in bad_dists]
    # One TSV line per source id:
    # src_id, topk_recalls, topk_dists, hitrate, bad_cases, bad_dists
    hitrate_writer.write('\n'.join([
        '\t'.join(line) for line in zip(src_ids, topk_recalls, topk_dists,
                                        hitrates, bad_cases, bad_dists)
    ]))
  print('total_hits: ', total_hits)
  print('total_gt_count: ', total_gt_count)
  return total_hits, total_gt_count


def gt_hdfs(gt_table, batch_size, gt_file_sep):
  """Yield batches of ground-truth records read from CSV files.

  Args:
    gt_table: a glob pattern, a comma-separated list of patterns, a directory,
      or a single file path.
    batch_size: number of records per yielded batch.
    gt_file_sep: field separator used inside each line.

  Yields:
    Lists of up to batch_size tuples; in each tuple, fields 0 (id) and
    3 (emb_num) are converted to int, the rest stay strings.
  """
  if '*' in gt_table or ',' in gt_table:
    file_paths = tf.gfile.Glob(gt_table.split(','))
  elif tf.gfile.IsDirectory(gt_table):
    file_paths = tf.gfile.Glob(os.path.join(gt_table, '*'))
  else:
    file_paths = tf.gfile.Glob(gt_table)

  batch_list, num_buffered = [], 0
  for file_path in file_paths:
    with tf.gfile.GFile(file_path, 'r') as fin:
      for gt in fin:
        num_buffered += 1
        gt_list = gt.strip().split(gt_file_sep)
        # id and emb_num columns must be ints for downstream computation.
        gt_list[0], gt_list[3] = int(gt_list[0]), int(gt_list[3])
        batch_list.append(tuple(gt_list))
        if num_buffered >= batch_size:
          yield batch_list
          batch_list, num_buffered = [], 0
  # Flush the trailing partial batch, if any.
  if num_buffered != 0:
    yield batch_list


def main():
  # Cluster topology comes from the TF_CONFIG env var set by the launcher.
  tf_config = json.loads(os.environ['TF_CONFIG'])
  worker_count = len(tf_config['cluster']['worker'])
  task_index = tf_config['task']['index']
  job_name = tf_config['task']['type']

  hitrate_details_result = FLAGS.hitrate_details_result
  total_hitrate_result = FLAGS.total_hitrate_result
  i_emb_table = FLAGS.item_emb_table
  gt_table = FLAGS.gt_table
  pipeline_config = config_util.get_configs_from_pipeline_file(
      FLAGS.pipeline_config_path)
  logging.info('i_emb_table %s', i_emb_table)

  # CSVInput reads files directly; any other input type is resolved through
  # Hive to the table's storage location.
  input_type = pipeline_config.data_config.input_type
  input_type_name = DatasetConfig.InputType.Name(input_type)
  if input_type_name == 'CSVInput':
    i_emb_table = process_multi_file_input_path(i_emb_table)
  else:
    hive_utils = HiveUtils(
        data_config=pipeline_config.data_config,
        hive_config=pipeline_config.hive_train_input)
    i_emb_table = hive_utils.get_table_location(i_emb_table)

  g = load_graph(i_emb_table, FLAGS.emb_dim, FLAGS.knn_metric, FLAGS.timeout,
                 FLAGS.knn_strict)
  gl.set_tracker_mode(0)
  gl.set_field_delimiter(FLAGS.item_emb_table_field_sep)

  cluster = tf.train.ClusterSpec({
      'ps': tf_config['cluster']['ps'],
      'worker': tf_config['cluster']['worker']
  })
  server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)
  if job_name == 'ps':
    server.join()
  else:
    # GraphLearn needs its own ports; derive them from the worker hosts by
    # replacing the TF port with 888<worker_index>.
    worker_hosts = [
        str(host.split(':')[0]) + ':888' + str(i)
        for i, host in enumerate(tf_config['cluster']['worker'])
    ]
    worker_hosts = ','.join(worker_hosts)
    g.init(task_index=task_index, task_count=worker_count, hosts=worker_hosts)

    # Your model, use g to do some operation, such as sampling
    if input_type_name == 'CSVInput':
      gt_all = gt_hdfs(gt_table, FLAGS.batch_size, FLAGS.gt_table_field_sep)
    else:
      gt_reader = HiveUtils(
          data_config=pipeline_config.data_config,
          hive_config=pipeline_config.hive_train_input,
          selected_cols='*')
      gt_all = gt_reader.hive_read_lines(gt_table, FLAGS.batch_size)

    # Each worker writes its own detail shard under the result directory.
    if not tf.gfile.IsDirectory(hitrate_details_result):
      tf.gfile.MakeDirs(hitrate_details_result)
    hitrate_details_result = os.path.join(hitrate_details_result,
                                          'part-%s' % task_index)
    details_writer = tf.gfile.GFile(hitrate_details_result, 'w')
    print('Start compute hitrate...')
    total_hits, total_gt_count = compute_hitrate(g, gt_all, details_writer,
                                                 gt_table)
    var_total_hitrate, var_worker_count = reduce_hitrate(
        cluster, total_hits, total_gt_count, task_index)

    with tf.train.MonitoredTrainingSession(
        master=server.target, is_chief=(task_index == 0)) as sess:
      outs = sess.run([var_total_hitrate, var_worker_count])
      # write after all workers have completed the calculation of hitrate.
      print('outs: ', outs)
      if outs[1] == worker_count:
        logging.info(outs)
        with tf.gfile.GFile(total_hitrate_result, 'w') as total_writer:
          total_writer.write(str(outs[0]))

    details_writer.close()
    g.close()
    print('Compute hitrate done.')


if __name__ == '__main__':
  sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
  main()