easy_rec/python/tools/hit_rate_ds.py:
# Copyright 2020 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
# """Evaluation of Top k hitrate."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import logging
import os
import sys
import graphlearn as gl
import tensorflow as tf
from easy_rec.python.protos.dataset_pb2 import DatasetConfig
from easy_rec.python.utils import config_util
from easy_rec.python.utils import io_util
from easy_rec.python.utils.config_util import process_multi_file_input_path
from easy_rec.python.utils.hit_rate_utils import compute_hitrate_batch
from easy_rec.python.utils.hit_rate_utils import load_graph
from easy_rec.python.utils.hit_rate_utils import reduce_hitrate
from easy_rec.python.utils.hive_utils import HiveUtils
if tf.__version__ >= '2.0':
tf = tf.compat.v1
from easy_rec.python.utils.distribution_utils import set_tf_config_and_get_train_worker_num_on_ds # NOQA
logging.basicConfig(
format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
level=logging.INFO)
tf.app.flags.DEFINE_string('item_emb_table', '', 'item embedding table name')
tf.app.flags.DEFINE_string('gt_table', '', 'ground truth table name')
tf.app.flags.DEFINE_string('hitrate_details_result', '',
'hitrate detail file path')
tf.app.flags.DEFINE_string('total_hitrate_result', '',
'total hitrate result file path')
tf.app.flags.DEFINE_string('pipeline_config_path', '', 'pipeline config path')
tf.app.flags.DEFINE_integer('batch_size', 512, 'batch size')
tf.app.flags.DEFINE_integer('emb_dim', 128, 'embedding dimension')
tf.app.flags.DEFINE_string('recall_type', 'i2i', 'i2i or u2i')
tf.app.flags.DEFINE_integer('top_k', 5, 'top k hitrate.')
tf.app.flags.DEFINE_integer('knn_metric', 0, '0(l2) or 1(ip).')
tf.app.flags.DEFINE_bool('knn_strict', False, 'use exact search.')
tf.app.flags.DEFINE_integer('timeout', 60, 'timeout')
tf.app.flags.DEFINE_integer('num_interests', 1, 'max number of interests')
tf.app.flags.DEFINE_string('gt_table_field_sep', '\t', 'gt_table_field_sep')
tf.app.flags.DEFINE_string('item_emb_table_field_sep', '\t',
'item_emb_table_field_sep')
tf.app.flags.DEFINE_bool('is_on_ds', False, 'is on ds')
FLAGS = tf.app.flags.FLAGS
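
# Example launch (illustrative only; the table names, paths and flag
# values below are assumptions, and TF_CONFIG must already be set in
# the environment, see main()):
#   python -m easy_rec.python.tools.hit_rate_ds \
#     --pipeline_config_path=pipeline.config \
#     --gt_table=gt.csv --item_emb_table=item_emb.csv \
#     --hitrate_details_result=hitrate_details \
#     --top_k=5 --emb_dim=128 --recall_type=i2i
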
def compute_hitrate(g, gt_all, hitrate_writer, gt_table=None):
"""Compute hitrate of each worker.
Args:
g: a GL Graph instance.
gt_reader: reader of input trigger_items_table.
hitrate_writer: writer of hitrate table.
gt_table: ground truth table.
Returns:
total_hits: total hits of this worker.
total_gt_count: total count of ground truth items of this worker.
"""
total_hits = 0.0
total_gt_count = 0.0
for gt_record in gt_all:
gt_record = list(gt_record)
hits, gt_count, src_ids, recall_ids, recall_distances, hitrates, bad_cases, bad_dists = \
compute_hitrate_batch(g, gt_record, FLAGS.emb_dim, FLAGS.num_interests, FLAGS.top_k)
total_hits += hits
total_gt_count += gt_count
src_ids = [str(ids) for ids in src_ids]
hitrates = [str(hitrate) for hitrate in hitrates]
topk_recalls = [','.join(str(x) for x in ids) for ids in recall_ids]
topk_dists = [
','.join('|'.join(str(x)
for x in dist)
for dist in dists)
for dists in recall_distances
]
bad_cases = [','.join(str(x) for x in bad_case) for bad_case in bad_cases]
bad_dists = [','.join(str(x) for x in dist) for dist in bad_dists]
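    # one tab-separated line per query: src id, top-k recall ids,
    # top-k distances, hitrate, bad-case ids, bad-case distances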
hitrate_writer.write('\n'.join([
'\t'.join(line) for line in zip(src_ids, topk_recalls, topk_dists,
hitrates, bad_cases, bad_dists)
]))
print('total_hits: ', total_hits)
print('total_gt_count: ', total_gt_count)
return total_hits, total_gt_count

def gt_hdfs(gt_table, batch_size, gt_file_sep):
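  """Yield batches of ground truth records from HDFS or local files.

  Args:
    gt_table: a file path, glob pattern, comma-separated patterns or
      a directory.
    batch_size: number of records per yielded batch.
    gt_file_sep: field separator within each line.

  Yields:
    Lists of record tuples, with the id (column 0) and embedding count
    (column 3) cast to int.
  """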
if '*' in gt_table or ',' in gt_table:
file_paths = tf.gfile.Glob(gt_table.split(','))
elif tf.gfile.IsDirectory(gt_table):
file_paths = tf.gfile.Glob(os.path.join(gt_table, '*'))
else:
file_paths = tf.gfile.Glob(gt_table)
batch_list, i = [], 0
for file_path in file_paths:
with tf.gfile.GFile(file_path, 'r') as fin:
for gt in fin:
i += 1
gt_list = gt.strip().split(gt_file_sep)
        # cast the id (column 0) and emb_num (column 3) fields to int
        gt_list[0], gt_list[3] = int(gt_list[0]), int(gt_list[3])
        batch_list.append(tuple(gt_list))
if i >= batch_size:
yield batch_list
batch_list, i = [], 0
if i != 0:
yield batch_list
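
# Example (illustrative path): gt_hdfs('hdfs://nn/path/gt/*', 512, '\t')
# yields lists of at most 512 parsed ground truth tuples per iteration.
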
def main():
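  # TF_CONFIG is expected in the environment; an illustrative value:
  #   {"cluster": {"ps": ["host0:2222"], "worker": ["host1:2222"]},
  #    "task": {"type": "worker", "index": 0}}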
tf_config = json.loads(os.environ['TF_CONFIG'])
worker_count = len(tf_config['cluster']['worker'])
  task_index = int(tf_config['task']['index'])
job_name = tf_config['task']['type']
hitrate_details_result = FLAGS.hitrate_details_result
total_hitrate_result = FLAGS.total_hitrate_result
i_emb_table = FLAGS.item_emb_table
gt_table = FLAGS.gt_table
pipeline_config = config_util.get_configs_from_pipeline_file(
FLAGS.pipeline_config_path)
logging.info('i_emb_table %s', i_emb_table)
input_type = pipeline_config.data_config.input_type
input_type_name = DatasetConfig.InputType.Name(input_type)
if input_type_name == 'CSVInput':
i_emb_table = process_multi_file_input_path(i_emb_table)
else:
hive_utils = HiveUtils(
data_config=pipeline_config.data_config,
hive_config=pipeline_config.hive_train_input)
i_emb_table = hive_utils.get_table_location(i_emb_table)
g = load_graph(i_emb_table, FLAGS.emb_dim, FLAGS.knn_metric, FLAGS.timeout,
FLAGS.knn_strict)
gl.set_tracker_mode(0)
gl.set_field_delimiter(FLAGS.item_emb_table_field_sep)
cluster = tf.train.ClusterSpec({
'ps': tf_config['cluster']['ps'],
'worker': tf_config['cluster']['worker']
})
server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)
if job_name == 'ps':
server.join()
else:
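    # build one GraphLearn endpoint per worker: the worker's host plus
    # port '888' + worker index (8880, 8881, ...); note the concatenation
    # assumes fewer than 10 workers to stay within the valid port range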
worker_hosts = [
      host.split(':')[0] + ':888' + str(i)
for i, host in enumerate(tf_config['cluster']['worker'])
]
worker_hosts = ','.join(worker_hosts)
g.init(task_index=task_index, task_count=worker_count, hosts=worker_hosts)
    # the graph is now initialized on all workers; next read ground truth
if input_type_name == 'CSVInput':
gt_all = gt_hdfs(gt_table, FLAGS.batch_size, FLAGS.gt_table_field_sep)
else:
gt_reader = HiveUtils(
data_config=pipeline_config.data_config,
hive_config=pipeline_config.hive_train_input,
selected_cols='*')
gt_all = gt_reader.hive_read_lines(gt_table, FLAGS.batch_size)
if not tf.gfile.IsDirectory(hitrate_details_result):
tf.gfile.MakeDirs(hitrate_details_result)
hitrate_details_result = os.path.join(hitrate_details_result,
'part-%s' % task_index)
details_writer = tf.gfile.GFile(hitrate_details_result, 'w')
    print('Start computing hitrate...')
total_hits, total_gt_count = compute_hitrate(g, gt_all, details_writer,
gt_table)
var_total_hitrate, var_worker_count = reduce_hitrate(
cluster, total_hits, total_gt_count, task_index)
with tf.train.MonitoredTrainingSession(
master=server.target, is_chief=(task_index == 0)) as sess:
outs = sess.run([var_total_hitrate, var_worker_count])
# write after all workers have completed the calculation of hitrate.
print('outs: ', outs)
if outs[1] == worker_count:
logging.info(outs)
with tf.gfile.GFile(total_hitrate_result, 'w') as total_writer:
total_writer.write(str(outs[0]))
details_writer.close()
g.close()
    print('Hitrate computation done.')

if __name__ == '__main__':
sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
main()