easy_rec/python/hpo/generate_hpo_sql.py (63 lines of code) (raw):
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Called by pai_hpo.py."""
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
'--sql_path', type=str, help='output sql path', default=None)
parser.add_argument(
'--config_path', type=str, help='config path', default=None)
parser.add_argument(
'--tables', type=str, help='train_table and test_table', default=None)
parser.add_argument(
'--train_tables', type=str, help='train_tables', default=None)
parser.add_argument(
'--eval_tables', type=str, help='eval_tables', default=None)
parser.add_argument(
'--cluster',
type=str,
help='specify tensorflow train jobs cluster parameter',
default=None)
parser.add_argument('--bucket', type=str, help='oss bucket', default=None)
parser.add_argument(
'--hpo_param_path', type=str, help='hpo param path', default=None)
parser.add_argument(
'--hpo_metric_save_path',
type=str,
help='hpo metric save path',
default=None)
parser.add_argument('--model_dir', type=str, help='model_dir', default=None)
parser.add_argument('--oss_host', type=str, help='oss endpoint', default=None)
parser.add_argument('--role_arn', type=str, help='role arn', default=None)
parser.add_argument(
'--algo_proj_name',
type=str,
help='algorithm project name',
default='algo_public')
parser.add_argument(
'--algo_res_proj', type=str, help='algo resource project', default=None)
parser.add_argument(
'--algo_version', type=str, help='algo version', default=None)
args = parser.parse_args()
with open(args.sql_path, 'w') as fout:
fout.write('pai -name easy_rec_ext -project %s\n' % args.algo_proj_name)
if args.algo_res_proj:
fout.write(' -Dres_project=%s\n' % args.algo_res_proj)
else:
fout.write(' -Dres_project=%s\n' % args.algo_proj_name)
if args.algo_version:
fout.write(' -Dversion=%s\n' % args.algo_version)
fout.write(' -Dconfig=%s\n' % args.config_path)
fout.write(' -Dcmd=train\n')
if args.tables:
fout.write(' -Dtables=%s\n' % args.tables)
else:
fout.write(' -Dtrain_tables=%s\n' % args.train_tables)
fout.write(' -Deval_tables=%s\n' % args.eval_tables)
fout.write(' -Dcluster=\'%s\'\n' % args.cluster)
fout.write(' -Darn=%s\n' % args.role_arn)
fout.write(' -Dbuckets=%s\n' % args.bucket)
fout.write(' -Dhpo_param_path=%s\n' % args.hpo_param_path)
fout.write(' -Dhpo_metric_save_path=%s\n' % args.hpo_metric_save_path)
fout.write(' -Dmodel_dir=%s\n' % args.model_dir)
fout.write(' -DossHost=%s\n' % args.oss_host)
fout.write(' -Deval_method=separate;\n')
print('write to %s' % args.sql_path)