easy_rec/python/tools/convert_rtp_fg.py (94 lines of code) (raw):
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Convert rtp fg feature config to easy_rec data_config and feature_config."""
import argparse
import logging
import sys
import tensorflow as tf
from easy_rec.python.utils.config_util import save_message
from easy_rec.python.utils.convert_rtp_fg import convert_rtp_fg
logging.basicConfig(
format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
level=logging.INFO)
if tf.__version__ >= '2.0':
tf = tf.compat.v1
model_types = ['deepfm', 'multi_tower', 'wide_and_deep', 'esmm', 'dbmtl', '']
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--model_type',
type=str,
choices=model_types,
default='',
help='model type, currently support: %s' % ','.join(model_types))
parser.add_argument('--rtp_fg', type=str, help='rtp fg path')
parser.add_argument(
'--embedding_dim', type=int, default=16, help='embedding_dimension')
parser.add_argument(
'--batch_size', type=int, default=1024, help='batch_size for train')
parser.add_argument(
'--label',
type=str,
default='',
nargs='+',
required=True,
help='label fields')
parser.add_argument(
'--num_steps',
type=int,
default=1000,
help='number of train steps = num_samples * num_epochs / batch_size / num_workers'
)
parser.add_argument('--output_path', type=str, help='generated config path')
parser.add_argument(
'--incol_separator',
type=str,
default='\003',
help='separator for multi_value features')
parser.add_argument(
'--separator',
type=str,
default='\002',
help='separator between different features')
parser.add_argument(
'--train_input_path', type=str, default=None, help='train data path')
parser.add_argument(
'--eval_input_path', type=str, default=None, help='eval data path')
parser.add_argument(
'--selected_cols',
type=str,
default=None,
help='selected cols, for csv input, it is in the format of: label_col_id0,...,lable_cold_idn,feature_col_id '
'for odps table input, it is in the format of: label_col_name0,...,label_col_namen,feature_col_name '
)
parser.add_argument(
'--rtp_separator', type=str, default=';', help='separator')
parser.add_argument(
'--input_type',
type=str,
default='OdpsRTPInput',
help='default to OdpsRTPInput, if test local, change it to RTPInput')
parser.add_argument(
'--is_async', action='store_true', help='async mode, debug to false')
args = parser.parse_args()
if not args.rtp_fg:
logging.error('rtp_fg is not set')
sys.exit(1)
if not args.output_path:
logging.error('output_path is not set')
sys.exit(1)
pipeline_config = convert_rtp_fg(args.rtp_fg, args.embedding_dim,
args.batch_size, args.label, args.num_steps,
args.model_type, args.separator,
args.incol_separator, args.train_input_path,
args.eval_input_path, args.selected_cols,
args.input_type, args.is_async)
save_message(pipeline_config, args.output_path)
logging.info('Conversion done.')
logging.info('Tips:')
logging.info(
'if run on local, please change data_config.input_type to RTPInput, '
'and model_dir/train_input_path/eval_input_path must also be set, ')
logging.info(
'if run local, please set data_config.selected_cols in the format '
'label_col_id0,label_col_id1,...,label_col_idn,feature_col_id')
logging.info(
'if run on odps, selected_cols must be set, which are label0_col,'
'label1_col, ..., feature_col_name')