# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Hyperparameter search demo for easy_rec on pai."""
import json
import logging
import os
import shutil
import time
from pai.automl import hpo
from easy_rec.python.utils import hpo_util

file_dir, _ = os.path.split(os.path.abspath(__file__))

logging.basicConfig(
level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
try:
import subprocess
subprocess.check_output('which odpscmd', shell=True)
except Exception:
logging.error(
'odpscmd is not in path, please install from https://help.aliyun.com/document_detail/27971.html'
)


def get_tuner(data, max_parallel, max_trial_num):
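  # `data` is the JSON string built by hpo_config below, carrying the
  # environment, hyperparams, tasks, earlystop and algorithm sections.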
param_dict = json.loads(data)
  if 'environment' in param_dict:
hpo.register_env(**param_dict['environment'])
# hyper param
params = []
for h in param_dict['hyperparams']:
param = hpo.hyperparam.create(**h)
params.append(param)
# tasks
tasks = []
for t in param_dict['tasks']:
r = None
    if 'metric_reader' in t:
      r = hpo.reader.create(**t['metric_reader'])
      t.pop('metric_reader')
if r:
subtask = hpo.task.create(metric_reader=r, **t)
else:
subtask = hpo.task.create(**t)
tasks.append(subtask)
# earlystop & algo
  early_stop = None
  if 'earlystop' in param_dict:
    early_stop = hpo.earlystop.create(**param_dict['earlystop'])
  algo = None
  if 'algorithm' in param_dict:
    algo = hpo.algorithm.create(**param_dict['algorithm'])
tuner = hpo.autotuner.AutoTuner(
earlystop=early_stop,
algorithm=algo,
hyperparams=params,
task_list=tasks,
max_parallel=max_parallel,
max_trial_num=max_trial_num,
mode='local',
user_id='your_cloud_id')
return tuner


def hpo_config(config_path, hyperparams, environment, exp_dir, tables,
train_tables, eval_tables, cluster, algo_proj_name,
algo_res_proj, algo_version, metric_name, odps_config_path):
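  # Builds the tuner config: an oss adapter task that writes the sampled
  # hyperparams, a BashTask that generates the per-trial train SQL, and a
  # BashTask that submits it via odpscmd and reads the metric back from oss.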
earlystop = {'type': 'large_is_better', 'max_runtime': 3600 * 12}
algorithm = {
'type': 'gp',
'initial_trials_num': 4,
'stop_when_exception': True
}
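  # exp_dir may be passed as oss://bucket/path; strip the scheme and bucket
  # name so that the derived paths below are relative to the bucket root.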
if exp_dir.startswith('oss://'):
exp_dir = exp_dir.replace('oss://', '')
exp_dir = exp_dir[exp_dir.find('/') + 1:]
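  # {{ trial.id }} is a template placeholder filled in per trial by the tuner,
  # so each trial gets its own param/metric/model paths.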
param_path = '%s/hpo_test_{{ trial.id }}.json' % exp_dir
metric_path = '%s/easy_rec_hpo_{{ trial.id }}.metric' % exp_dir
model_path = '%s/easy_rec_hpo_{{ trial.id }}' % exp_dir
bucket = 'oss://' + environment['bucket'].strip('/') + '/'
adapter_task = {
'type': 'ossadaptertask',
# hpo_param_path for easy_rec
'param_file': param_path,
}
  tmp_dir = '/tmp/pai_easy_rec_hpo_%d' % int(time.time())
os.makedirs(tmp_dir)
logging.info('local temporary path: %s' % tmp_dir)
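  # per-trial SQL files generated below are staged under this local directory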

  def _add_prefix(table_name):
table_name = table_name.strip()
if not table_name.startswith('odps://'):
return 'odps://%s/tables/%s' % (environment['project'], table_name)
else:
return table_name

  if tables:
tables = [_add_prefix(x) for x in tables.split(',') if x != '']
tables = ','.join(tables)
logging.info('will tune on data: %s' % tables)
else:
train_tables = [_add_prefix(x) for x in train_tables.split(',') if x != '']
train_tables = ','.join(train_tables)
eval_tables = [_add_prefix(x) for x in eval_tables.split(',') if x != '']
eval_tables = ','.join(eval_tables)
sql_path = '%s/train_ext_hpo_{{ trial.id }}.sql' % tmp_dir
cmd_args = [
'python', '-m', 'easy_rec.python.hpo.generate_hpo_sql', '--sql_path',
sql_path, '--config_path', config_path, '--cluster', cluster, '--bucket',
bucket, '--hpo_param_path',
os.path.join(bucket, param_path), '--hpo_metric_save_path',
os.path.join(bucket, metric_path), '--model_dir',
os.path.join(bucket,
model_path), '--oss_host', environment['oss_endpoint'],
'--role_arn', environment['role_arn'], '--algo_proj_name', algo_proj_name
]
if tables:
cmd_args.extend(['--tables', tables])
if train_tables and eval_tables:
cmd_args.extend(
['--train_tables', train_tables, '--eval_tables', eval_tables])
if algo_res_proj:
cmd_args.extend(['--algo_res_proj', algo_res_proj])
if algo_version:
cmd_args.extend(['--algo_version', algo_version])
prepare_sql_task = {'type': 'BashTask', 'cmd': cmd_args}
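  # The train task submits the generated SQL via odpscmd; oss_reader then
  # extracts the metric value (e.g. '"auc": 0.75') from the metric file each
  # trial writes to oss.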
train_task = {
'type': 'BashTask',
'cmd': ['odpscmd',
'--config=%s' % odps_config_path, '-f', sql_path],
'metric_reader': {
'type': 'oss_reader',
'location': metric_path,
'parser_pattern': '.*"%s": (\\d.\\d+).*' % metric_name
}
}
tasks = [adapter_task, prepare_sql_task, train_task]
data = {
'earlystop': earlystop,
'algorithm': algorithm,
'hyperparams': hyperparams,
'tasks': tasks,
'environment': environment
}
return data, tmp_dir


if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
'--odps_config', type=str, help='odps_config.ini', default=None)
  parser.add_argument(
      '--oss_config', type=str, help='ossutil config path', default='')
parser.add_argument('--bucket', type=str, help='bucket name', default=None)
parser.add_argument('--role_arn', type=str, help='role arn', default=None)
parser.add_argument(
'--hyperparams', type=str, help='hyper parameters', default=None)
parser.add_argument(
'--config_path', type=str, help='pipeline config', default=None)
parser.add_argument(
'--tables', type=str, help='train table and test table', default=None)
parser.add_argument(
'--train_tables', type=str, help='train tables', default=None)
parser.add_argument(
'--eval_tables', type=str, help='eval tables', default=None)
parser.add_argument(
'--exp_dir', type=str, help='hpo experiment directory', default=None)
parser.add_argument(
'--cluster',
type=str,
help='cluster spec',
default='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":3, "cpu":1000, "gpu":100, "memory":40000}}'
)
parser.add_argument(
'--algo_proj_name',
type=str,
help='algo project name',
default='algo_public')
parser.add_argument(
'--algo_version', type=str, help='algo version', default=None)
parser.add_argument(
'--algo_res_proj', type=str, help='algo resource project', default=None)
parser.add_argument(
'--metric_name', type=str, help='evaluate metric name', default='auc')
parser.add_argument(
'--max_parallel',
type=int,
help='max number of trials run at the same time',
default=4)
parser.add_argument(
'--total_trial_num',
type=int,
help='total number of trials will run',
default=6)
parser.add_argument(
'--debug',
action='store_true',
help='debug mode, will keep the temporary folder')
args = parser.parse_args()
  assert args.odps_config is not None and os.path.exists(args.odps_config), \
      'odps config file %s does not exist' % args.odps_config

  def _load_kv_config(path):
    """Parse a key=value config file, skipping blank lines and # comments."""
    kv_config = {}
    with open(path, 'r') as fin:
      for line_str in fin:
        line_str = line_str.strip()
        if len(line_str) == 0 or line_str[0] == '#':
          continue
        if '=' in line_str:
          tmp_id = line_str.find('=')
          key = line_str[:tmp_id].strip()
          val = line_str[(tmp_id + 1):].strip()
          kv_config[key] = val
    return kv_config

  odps_config = _load_kv_config(args.odps_config)

  # fall back to the default ossutil config if --oss_config is not set
  if not args.oss_config:
    args.oss_config = os.path.join(os.environ['HOME'], '.ossutilconfig')
  assert os.path.exists(args.oss_config), \
      'oss config file %s does not exist' % args.oss_config
  oss_config = _load_kv_config(args.oss_config)
  assert args.bucket is not None, '--bucket must be specified'
  assert args.role_arn is not None, '--role_arn must be specified'
  if args.bucket.startswith('oss://'):
    args.bucket = args.bucket[len('oss://'):]
  args.bucket = args.bucket.strip('/')
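  # credentials and endpoints from odps_config.ini / .ossutilconfig, merged
  # into the environment section consumed by hpo.register_env in get_tuner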
environment = {
'access_id': odps_config['access_id'],
'access_key': odps_config['access_key'],
'oss_access_id': oss_config['accessKeyID'],
'oss_access_key': oss_config['accessKeySecret'],
'project': odps_config['project_name'],
'odps_endpoint': odps_config['end_point'],
'biz_id': '147331^paistudio^xxxxxxx^2020-03-18',
'role_arn': args.role_arn,
'bucket': args.bucket,
'oss_endpoint': oss_config['endpoint']
}
  assert args.hyperparams is not None, '--hyperparams must be specified'
  with open(args.hyperparams, 'r') as fin:
    hyperparams = json.load(fin)
  assert args.config_path is not None, '--config_path must be specified'
  assert args.exp_dir is not None, '--exp_dir must be specified'
  assert args.tables is not None or (
      args.train_tables is not None and args.eval_tables is not None
  ), 'either --tables or both --train_tables and --eval_tables must be set'
data, tmp_dir = hpo_config(args.config_path, hyperparams, environment,
args.exp_dir, args.tables, args.train_tables,
args.eval_tables, args.cluster,
args.algo_proj_name, args.algo_res_proj,
args.algo_version, args.metric_name,
args.odps_config)
hpo_util.kill_old_proc(tmp_dir, platform='pai')
data_json = json.dumps(data)
tuner = get_tuner(data_json, args.max_parallel, args.total_trial_num)
tuner.fit(synchronize=True)
if not args.debug:
shutil.rmtree(tmp_dir)
else:
logging.info('temporary directory is: %s' % tmp_dir)