easy_rec/python/tools/convert_rtp_data.py (57 lines of code) (raw):
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Convert the original rtp data format to csv format.
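
The original data format is not suggested to use with EasyRec.
In the original format, features are stored as kv pairs: pairs are
separated by '\\002' and each key is separated from its value by '\\003'.
If a feature has more than one value, there will be multiple kvs, such as:
    ...tag\\003beauty\\002tag\\003smart...
In our new format, each line is a csv row holding the label column followed
by one column per feature listed in the fg json; multiple values of a
feature are joined with '|', such as:
    ...,beauty|smart,...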
"""
import argparse
import csv
import json
import logging
import sys
import tensorflow as tf
logging.basicConfig(
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
    level=logging.INFO)

# Under TensorFlow 2.x, rebind `tf` to the TF1-compatible API surface.
# Compare the numeric major version: comparing version strings is
# lexicographic, so e.g. '10.0' >= '2.0' would evaluate to False.
# NOTE(review): `tf` is not referenced by the conversion code visible in this
# script (it reads/writes with the builtin `open`); confirm the alias is needed.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1
def parse_rtp_line(line_str):
  """Parse one rtp-format line into a dict of feature name -> values.

  Args:
    line_str: a single rtp data line without its trailing newline. kv pairs
      are separated by '\\002' and each key is separated from its value by
      '\\003' (only the first '\\003' in a pair is treated as the separator).
      A key appears several times when the feature has several values.

  Returns:
    A dict mapping each key to the list of its values, in the order they
    appear in the line. Empty tokens (e.g. from a trailing '\\002') are
    ignored.

  Raises:
    ValueError: if a non-empty token has no '\\003' key/value separator.
  """
  kvs = {}
  for line_tok in line_str.split('\002'):
    if not line_tok:
      # A trailing or doubled '\002' yields an empty token; skip it instead
      # of failing to unpack it.
      continue
    k, sep, v = line_tok.partition('\003')
    if not sep:
      raise ValueError(
          'malformed kv token (no \\003 separator): %r' % line_tok)
    kvs.setdefault(k, []).append(v)
  return kvs


def to_csv_row(kvs, feature_names):
  """Build one csv row from parsed kvs.

  Args:
    kvs: dict of feature name -> list of string values, as returned by
      parse_rtp_line.
    feature_names: column order of the output row.

  Returns:
    A list with one string per entry of feature_names: that feature's values
    joined with '|', or '' when the feature is absent from kvs.
  """
  return ['|'.join(kvs.get(name, [])) for name in feature_names]


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--rtp_fg', type=str, default='', help='rtp fg path(.json)')
  parser.add_argument('--input_path', type=str, default='', help='input path')
  parser.add_argument('--output_path', type=str, default='', help='output path')
  parser.add_argument('--label', type=str, default='', help='label for train')
  args = parser.parse_args()

  # Every argument is required; report the first missing one and abort.
  for arg_name in ('rtp_fg', 'input_path', 'output_path', 'label'):
    if not getattr(args, arg_name):
      logging.error('%s is not set' % arg_name)
      sys.exit(1)

  with open(args.rtp_fg, 'r') as fin:
    rtp_fg = json.load(fin)
  # Output column order: the label first, then each fg feature in fg order.
  feature_names = [args.label] + [
      feature['feature_name'] for feature in rtp_fg['features']
  ]

  with open(args.input_path, 'r') as fin, open(args.output_path, 'w') as fout:
    writer = csv.writer(fout)
    for line_no, line_str in enumerate(fin, 1):
      line_str = line_str.strip()
      if not line_str:
        # A blank line carries no sample; converting it would emit an
        # all-empty record (including an empty label).
        continue
      try:
        kvs = parse_rtp_line(line_str)
      except ValueError as ex:
        logging.error('%s:%d: %s' % (args.input_path, line_no, ex))
        sys.exit(1)
      writer.writerow(to_csv_row(kvs, feature_names))