easy_rec/python/tools/hit_rate_pai.py (96 lines of code) (raw):

# Copyright 2020 Alibaba Group Holding Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= """Evaluation of Top k hitrate.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import sys import tensorflow as tf from easy_rec.python.utils import io_util from easy_rec.python.utils.hit_rate_utils import compute_hitrate_batch from easy_rec.python.utils.hit_rate_utils import load_graph from easy_rec.python.utils.hit_rate_utils import reduce_hitrate flags = tf.app.flags FLAGS = flags.FLAGS flags.DEFINE_integer('task_index', None, 'Task index') flags.DEFINE_integer('task_count', None, 'Task count') flags.DEFINE_string('job_name', None, 'worker or ps or aligraph') flags.DEFINE_string('ps_hosts', '', 'ps hosts') flags.DEFINE_string('worker_hosts', '', 'worker hosts') flags.DEFINE_string('tables', '', 'input odps tables name') flags.DEFINE_string('outputs', '', 'ouput odps tables name') flags.DEFINE_integer('batch_size', 512, 'batch size') flags.DEFINE_integer('emb_dim', 128, 'embedding dimension') flags.DEFINE_string('recall_type', 'i2i', 'i2i or u2i') flags.DEFINE_integer('top_k', '5', 'top_k hitrate.') flags.DEFINE_integer('knn_metric', '0', '0(l2) or 1(ip).') flags.DEFINE_bool('knn_strict', False, 'use exact search.') flags.DEFINE_integer('timeout', '60', 'timeout') flags.DEFINE_integer('num_interests', 1, 'max number of interests') def compute_hitrate(g, gt_reader, hitrate_writer): """Compute hitrate of each worker. Args: g: a GL Graph instance. gt_reader: odps reader of input trigger_items_table. hitrate_writer: odps writer of hitrate table. Returns: total_hits: total hits of this worker. total_gt_count: total count of ground truth items of this worker. """ total_hits = 0.0 total_gt_count = 0.0 while True: try: gt_record = gt_reader.read(FLAGS.batch_size) hits, gt_count, src_ids, recall_ids, recall_distances, hitrates, bad_cases, bad_dists = \ compute_hitrate_batch(g, gt_record, FLAGS.emb_dim, FLAGS.num_interests, FLAGS.top_k) total_hits += hits total_gt_count += gt_count topk_recalls = [','.join(str(x) for x in ids) for ids in recall_ids] topk_dists = [ ','.join(str(x) for x in dists) for dists in recall_distances ] bad_cases = [','.join(str(x) for x in case) for case in bad_cases] bad_dists = [','.join(str(x) for x in dist) for dist in bad_dists] hitrate_writer.write( list( zip(src_ids, topk_recalls, topk_dists, hitrates, bad_cases, bad_dists)), indices=[0, 1, 2, 3, 4, 5]) except tf.python_io.OutOfRangeException: break return total_hits, total_gt_count def main(): worker_count = len(FLAGS.worker_hosts.split(',')) input_tables = FLAGS.tables.split(',') if FLAGS.recall_type == 'u2i': i_emb_table, gt_table = input_tables g = load_graph(i_emb_table, FLAGS.emb_dim, FLAGS.knn_metric, FLAGS.timeout, FLAGS.knn_strict) else: i_emb_table, gt_table = input_tables[-2], input_tables[-1] g = load_graph(i_emb_table, FLAGS.emb_dim, FLAGS.knn_metric, FLAGS.timeout, FLAGS.knn_strict) hitrate_details_table, total_hitrate_table = FLAGS.outputs.split(',') cluster = tf.train.ClusterSpec({ 'ps': FLAGS.ps_hosts.split(','), 'worker': FLAGS.worker_hosts.split(',') }) server = tf.train.Server( cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) if FLAGS.job_name == 'ps': server.join() else: g.init(task_index=FLAGS.task_index, task_count=worker_count) gt_reader = tf.python_io.TableReader( gt_table, slice_id=FLAGS.task_index, slice_count=worker_count, capacity=2048) details_writer = tf.python_io.TableWriter( hitrate_details_table, slice_id=FLAGS.task_index) print('Start compute hitrate...') total_hits, total_gt_count = compute_hitrate(g, gt_reader, details_writer) var_total_hitrate, var_worker_count = reduce_hitrate( cluster, total_hits, total_gt_count, FLAGS.task_index) with tf.train.MonitoredTrainingSession( master=server.target, is_chief=(FLAGS.task_index == 0)) as sess: outs = sess.run([var_total_hitrate, var_worker_count]) # write after all workers have completed the calculation of hitrate. if outs[1] == worker_count: with tf.python_io.TableWriter(total_hitrate_table) as total_writer: total_writer.write([outs[0]], indices=[0]) gt_reader.close() details_writer.close() g.close() print('Compute hitrate done.') if __name__ == '__main__': sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv) main()