pai_jobs/run.py (505 lines of code) (raw):
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from __future__ import print_function
import logging
# use few threads to avoid oss error
import os
import time
import tensorflow as tf
import yaml
from tensorflow.python.platform import gfile
import easy_rec
from easy_rec.python.inference.odps_predictor import ODPSPredictor
from easy_rec.python.inference.vector_retrieve import VectorRetrieve
from easy_rec.python.tools.pre_check import run_check
from easy_rec.python.utils import config_util
from easy_rec.python.utils import constant
from easy_rec.python.utils import estimator_utils
from easy_rec.python.utils import fg_util
from easy_rec.python.utils import hpo_util
from easy_rec.python.utils import pai_util
from easy_rec.python.utils.distribution_utils import DistributionStrategyMap
from easy_rec.python.utils.distribution_utils import set_distribution_config
# Import-time bootstrap. Statement order matters here, which is why the
# late imports below carry NOQA markers.
# NOTE(review): IS_ON_PAI is set before the next easy_rec import, presumably
# because that module reads it at import time - confirm.
os.environ['IS_ON_PAI'] = '1'
from easy_rec.python.utils.distribution_utils import set_tf_config_and_get_train_worker_num  # NOQA

# use few threads to avoid oss error
os.environ['OENV_MultiWriteThreadsNum'] = '4'
os.environ['OENV_MultiCopyThreadsNum'] = '4'

# For every TensorFlow version other than 1.12 the rest of this file talks to
# the TF1 compatibility API through the rebound name `tf`.
if not tf.__version__.startswith('1.12'):
  tf = tf.compat.v1
  try:
    # imported only for its side effects (the name itself is unused);
    # a failure is logged and otherwise ignored
    import tensorflow_io as tfio  # noqa: F401
  except Exception as ex:
    logging.error('failed to import tfio: %s' % str(ex))
  tf.disable_eager_execution()

from easy_rec.python.main import _train_and_evaluate_impl as train_and_evaluate_impl  # NOQA

logging.basicConfig(
    level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
# flags describing the cluster layout and the role of the current task
tf.app.flags.DEFINE_string('worker_hosts', '',
                           'Comma-separated list of hostname:port pairs')
tf.app.flags.DEFINE_string('ps_hosts', '',
                           'Comma-separated list of hostname:port pairs')
tf.app.flags.DEFINE_string('job_name', '', 'task type, ps/worker')
tf.app.flags.DEFINE_integer('task_index', 0, 'Index of task within the job')

# common flags: pipeline config, sub-command and input tables.
# main() additionally accepts the commands predict, export_checkpoint,
# vector_retrieve and check.
tf.app.flags.DEFINE_string('config', '', 'EasyRec config file path')
tf.app.flags.DEFINE_string('cmd', 'train',
                           'command type, train/evaluate/export')
tf.app.flags.DEFINE_string('tables', '', 'tables passed by pai command')

# flags for train
tf.app.flags.DEFINE_integer('num_gpus_per_worker', 1,
                            'number of gpu to use in training')
# with_evaluator is the legacy switch: when set, main() forces
# eval_method to 'separate'.
tf.app.flags.DEFINE_boolean('with_evaluator', False,
                            'whether a evaluator is necessary')
tf.app.flags.DEFINE_string(
    'eval_method', 'none', 'default to none, choices are [none: not evaluate,' +
    'master: evaluate on master, separate: evaluate on a separate task]')
# must be a key of DistributionStrategyMap (asserted in main())
tf.app.flags.DEFINE_string('distribute_strategy', '',
                           'training distribute strategy')
tf.app.flags.DEFINE_string('edit_config_json', '', 'edit config json string')
# train_tables / eval_tables take precedence over the generic `tables` flag
tf.app.flags.DEFINE_string('train_tables', '', 'tables used for train')
tf.app.flags.DEFINE_string('eval_tables', '', 'tables used for evaluation')
tf.app.flags.DEFINE_string('boundary_table', '', 'tables used for boundary')
tf.app.flags.DEFINE_string('sampler_table', '', 'tables used for sampler')
tf.app.flags.DEFINE_string('fine_tune_checkpoint', None,
                           'finetune checkpoint path')

# flags for vector_retrieve
tf.app.flags.DEFINE_string('query_table', '',
                           'table used for retrieve vector neighbours')
tf.app.flags.DEFINE_string('doc_table', '',
                           'table used for be retrieved as indexed vectors')
tf.app.flags.DEFINE_enum('knn_distance', 'inner_product',
                         ['l2', 'inner_product'], 'type of knn distance')
# the next two have no default; main() asserts that they are set for
# cmd=vector_retrieve
tf.app.flags.DEFINE_integer('knn_num_neighbours', None,
                            'top n neighbours to be retrieved')
tf.app.flags.DEFINE_integer('knn_feature_dims', None,
                            'number of feature dimensions')
# Index type forwarded to VectorRetrieve by cmd=vector_retrieve.
# 'gpu_ivfpq' is the spelling used by the `knn_compress_dim` help text; the
# historical misspelling 'gpu_ivfpg' stays accepted so that existing job
# definitions keep passing flag validation.
# NOTE(review): confirm which of the two spellings VectorRetrieve expects.
tf.app.flags.DEFINE_enum('knn_index_type', 'ivfflat', [
    'flat', 'ivfflat', 'ivfpq', 'gpu_flat', 'gpu_ivfflat', 'gpu_ivfpq',
    'gpu_ivfpg'
], 'knn index type')
# more flags for vector_retrieve
tf.app.flags.DEFINE_string('knn_feature_delimiter', ',',
                           'delimiter for feature vectors')
tf.app.flags.DEFINE_integer('knn_nlist', 5,
                            'number of split part on each worker')
tf.app.flags.DEFINE_integer('knn_nprobe', 2,
                            'number of probe part on each worker')
tf.app.flags.DEFINE_integer(
    'knn_compress_dim', 8,
    'number of dimensions after compress for `ivfpq` and `gpu_ivfpq`')

# flags used for evaluate & export
tf.app.flags.DEFINE_string(
    'checkpoint_path', '', 'checkpoint to be evaluated or exported '
    'if not specified, use the latest checkpoint '
    'in train_config.model_dir')

# flags used for evaluate
tf.app.flags.DEFINE_string('eval_result_path', 'eval_result.txt',
                           'eval result metric file')
tf.app.flags.DEFINE_bool('distribute_eval', False,
                         'use distribute parameter server for train and eval.')

# flags used for export
tf.app.flags.DEFINE_string('export_dir', '',
                           'directory where model should be exported to')
tf.app.flags.DEFINE_bool('clear_export', False, 'remove export_dir if exists')
tf.app.flags.DEFINE_string('export_done_file', '',
                           'a flag file to signal that export model is done')
# main() waits for a checkpoint before dispatching any command when this
# value is greater than 0
tf.app.flags.DEFINE_integer('max_wait_ckpt_ts', 0,
                            'max wait time in seconds for checkpoints')
# forwarded to train_and_evaluate_impl by cmd=train
tf.app.flags.DEFINE_boolean('continue_train', True,
                            'use the same model to continue train or not')

# flags used for predict
tf.app.flags.DEFINE_string('saved_model_dir', '',
                           'directory where saved_model.pb exists')
tf.app.flags.DEFINE_string('outputs', '', 'output tables')
tf.app.flags.DEFINE_string(
    'all_cols', '',
    'union of (selected_cols, reserved_cols), separated with , ')
tf.app.flags.DEFINE_string(
    'all_col_types', '',
    'column data types, for build record defaults, separated with ,')
tf.app.flags.DEFINE_string(
    'selected_cols', '',
    'columns to keep from input table, they are separated with ,')
tf.app.flags.DEFINE_string(
    'reserved_cols', '',
    'columns to keep from input table, they are separated with ,')
tf.app.flags.DEFINE_string(
    'output_cols', None,
    'output columns, such as: score float. multiple columns are separated by ,')
tf.app.flags.DEFINE_integer('batch_size', 1024, 'predict batch size')
tf.app.flags.DEFINE_string(
    'profiling_file', None,
    'time stat file which can be viewed using chrome tracing')

# redis options; main() collects them for cmd=export
tf.app.flags.DEFINE_string('redis_url', None, 'export to redis url, host:port')
tf.app.flags.DEFINE_string('redis_passwd', None, 'export to redis passwd')
tf.app.flags.DEFINE_integer('redis_threads', 5, 'export to redis threads')
tf.app.flags.DEFINE_integer('redis_batch_size', 1024,
                            'export to redis batch_size')
tf.app.flags.DEFINE_integer('redis_timeout', 600,
                            'export to redis time_out in seconds')
tf.app.flags.DEFINE_integer('redis_expire', 24,
                            'export to redis expire time in hour')
tf.app.flags.DEFINE_string('redis_embedding_version', '',
                           'redis embedding version')
tf.app.flags.DEFINE_integer('redis_write_kv', 1, 'whether write kv ')

# oss options; main() collects them for cmd=export
tf.app.flags.DEFINE_string(
    'oss_path', None, 'write embed objects to oss folder, oss://bucket/folder')
tf.app.flags.DEFINE_string('oss_endpoint', None, 'oss endpoint')
tf.app.flags.DEFINE_string('oss_ak', None, 'oss ak')
tf.app.flags.DEFINE_string('oss_sk', None, 'oss sk')
tf.app.flags.DEFINE_integer('oss_threads', 10,
                            '# threads access oss at the same time')
tf.app.flags.DEFINE_integer('oss_timeout', 10,
                            'connect to oss, time_out in seconds')
tf.app.flags.DEFINE_integer('oss_expire', 24, 'oss expire time in hours')
tf.app.flags.DEFINE_integer('oss_write_kv', 1,
                            'whether to write embedding to oss')
tf.app.flags.DEFINE_string('oss_embedding_version', '', 'oss embedding version')

tf.app.flags.DEFINE_bool('verbose', False, 'print more debug information')
tf.app.flags.DEFINE_bool('place_embedding_on_cpu', False,
                         'whether to place embedding variables on cpu')

# for automl hyper parameter tuning
tf.app.flags.DEFINE_string('model_dir', None, 'model directory')
tf.app.flags.DEFINE_bool('clear_model', False,
                         'remove model directory if exists')
tf.app.flags.DEFINE_string('hpo_param_path', None,
                           'hyperparameter tuning param path')
tf.app.flags.DEFINE_string('hpo_metric_save_path', None,
                           'hyperparameter save metric path')

tf.app.flags.DEFINE_string('asset_files', None, 'extra files to add to export')
# forwarded to train_and_evaluate_impl by cmd=train
tf.app.flags.DEFINE_bool('check_mode', False, 'is use check mode')
# forwarded to ODPSPredictor by cmd=predict
tf.app.flags.DEFINE_string('fg_json_path', None, '')
tf.app.flags.DEFINE_bool('enable_avx_str_split', False,
                         'enable avx str split to speedup')

# parsed flag values, shared by every function in this file
FLAGS = tf.app.flags.FLAGS
def check_param(name):
  """Assert that the command line flag called `name` is not an empty string.

  Args:
    name: name of an attribute of the module-level FLAGS object.

  Raises:
    AssertionError: if the flag value equals ''.
  """
  flag_value = getattr(FLAGS, name)
  is_empty = flag_value == ''
  assert not is_empty, '%s should not be empty' % name
def set_selected_cols(pipeline_config, selected_cols, all_cols, all_col_types):
  """Write the selected columns and their types into `data_config`.

  Nothing is written when `selected_cols` is empty. The resulting
  `data_config.selected_cols` and `data_config.selected_col_types` are always
  printed.

  Args:
    pipeline_config: config object whose `data_config.selected_cols` and
      `data_config.selected_col_types` fields are updated.
    selected_cols: comma separated names of the columns to select.
    all_cols: comma separated names of all columns; when empty, no column
      types are written.
    all_col_types: comma separated column types, matched to `all_cols` by
      position.

  Raises:
    KeyError: if a selected column has no entry in `all_cols`/`all_col_types`.
  """
  if selected_cols:
    pipeline_config.data_config.selected_cols = selected_cols
    # add column types which will be used by OdpsInput, OdpsInputV2
    # to check consistency with input_fields.input_type
    if all_cols:
      col_type_map = {
          col.strip(): col_type.strip()
          for col, col_type in zip(all_cols.split(','), all_col_types.split(','))
      }
      selected_names = [col.strip() for col in selected_cols.split(',')]
      # fail with an actionable message instead of a bare KeyError('<name>')
      missing = [col for col in selected_names if col not in col_type_map]
      if missing:
        raise KeyError(
            'selected_cols [%s] have no type in all_cols="%s", all_col_types="%s"'
            % (','.join(missing), all_cols, all_col_types))
      pipeline_config.data_config.selected_col_types = ','.join(
          col_type_map[col] for col in selected_names)

  print('[run.py] data_config.selected_cols = "%s"' %
        pipeline_config.data_config.selected_cols)
  print('[run.py] data_config.selected_col_types = "%s"' %
        pipeline_config.data_config.selected_col_types)
def _wait_ckpt(ckpt_path, max_wait_ts):
logging.info('will wait %s seconds for checkpoint' % max_wait_ts)
start_ts = time.time()
if '/model.ckpt-' not in ckpt_path:
while time.time() - start_ts < max_wait_ts:
tmp_ckpt = estimator_utils.latest_checkpoint(ckpt_path)
if tmp_ckpt is None:
logging.info('wait for checkpoint in directory[%s]' % ckpt_path)
time.sleep(30)
else:
logging.info('find checkpoint[%s] in directory[%s]' %
(tmp_ckpt, ckpt_path))
break
else:
while time.time() - start_ts < max_wait_ts:
if not gfile.Exists(ckpt_path + '.index'):
logging.info('wait for checkpoint[%s]' % ckpt_path)
time.sleep(30)
else:
logging.info('find checkpoint[%s]' % ckpt_path)
break
def main(argv):
  """Run one EasyRec command on PAI, selected by `FLAGS.cmd`.

  Supported commands are train, evaluate, export, predict, export_checkpoint,
  vector_retrieve and check. When `FLAGS.config` is given, the pipeline config
  is loaded and patched from the command line flags (edit_config_json,
  model_dir, asset_files, ...) before the command is dispatched.

  Args:
    argv: unused.

  Raises:
    AssertionError: if a required flag is empty or a flag value is invalid.
    ValueError: if `FLAGS.cmd` is not a supported command.
  """
  pai_util.set_on_pai()
  if FLAGS.enable_avx_str_split:
    constant.enable_avx_str_split()
    logging.info('will enable avx str split: %s' %
                 constant.is_avx_str_split_enabled())

  if FLAGS.distribute_eval:
    os.environ['distribute_eval'] = 'True'

  # load lookup op; this is best effort, a failure is only reported
  try:
    lookup_op_path = os.path.join(easy_rec.ops_dir, 'libembed_op.so')
    tf.load_op_library(lookup_op_path)
  except Exception as ex:
    print('Error: exception: %s' % str(ex))

  num_gpus_per_worker = FLAGS.num_gpus_per_worker
  worker_hosts = FLAGS.worker_hosts.split(',')
  num_worker = len(worker_hosts)
  assert FLAGS.distribute_strategy in DistributionStrategyMap, \
      'invalid distribute_strategy [%s], available ones are %s' % (
          FLAGS.distribute_strategy, ','.join(DistributionStrategyMap.keys()))

  if FLAGS.config:
    config = pai_util.process_config(FLAGS.config, FLAGS.task_index,
                                     len(FLAGS.worker_hosts.split(',')))
    pipeline_config = config_util.get_configs_from_pipeline_file(config, False)

    # should be in front of edit_config_json step
    # otherwise data_config and feature_config are not ready
    if pipeline_config.fg_json_path:
      fg_util.load_fg_json_to_config(pipeline_config)

  if FLAGS.edit_config_json:
    print('[run.py] edit_config_json = %s' % FLAGS.edit_config_json)
    config_json = yaml.safe_load(FLAGS.edit_config_json)
    config_util.edit_config(pipeline_config, config_json)

  if FLAGS.model_dir:
    pipeline_config.model_dir = FLAGS.model_dir
    pipeline_config.model_dir = pipeline_config.model_dir.strip()
    print('[run.py] update model_dir to %s' % pipeline_config.model_dir)
    assert pipeline_config.model_dir.startswith(
        'oss://'), 'invalid model_dir format: %s' % pipeline_config.model_dir

  if FLAGS.asset_files:
    pipeline_config.export_config.asset_files.extend(
        FLAGS.asset_files.split(','))

  if FLAGS.config:
    if not pipeline_config.model_dir.endswith('/'):
      pipeline_config.model_dir += '/'

  if FLAGS.clear_model:
    # only the chief removes the directory
    if gfile.IsDirectory(
        pipeline_config.model_dir) and estimator_utils.is_chief():
      gfile.DeleteRecursively(pipeline_config.model_dir)

  if FLAGS.max_wait_ckpt_ts > 0:
    if FLAGS.checkpoint_path:
      _wait_ckpt(FLAGS.checkpoint_path, FLAGS.max_wait_ckpt_ts)
    else:
      _wait_ckpt(pipeline_config.model_dir, FLAGS.max_wait_ckpt_ts)

  if FLAGS.cmd == 'train':
    assert FLAGS.config, 'config should not be empty when training!'

    # `tables` must hold the train table followed by the eval table
    if not FLAGS.train_tables and FLAGS.tables:
      tables = FLAGS.tables.split(',')
      assert len(
          tables
      ) >= 2, 'at least 2 tables must be specified, but only[%d]: %s' % (
          len(tables), FLAGS.tables)

    if FLAGS.train_tables:
      pipeline_config.train_input_path = FLAGS.train_tables
    elif FLAGS.tables:
      pipeline_config.train_input_path = FLAGS.tables.split(',')[0]

    if FLAGS.eval_tables:
      pipeline_config.eval_input_path = FLAGS.eval_tables
    elif FLAGS.tables:
      pipeline_config.eval_input_path = FLAGS.tables.split(',')[1]

    print('[run.py] train_tables: %s' % pipeline_config.train_input_path)
    print('[run.py] eval_tables: %s' % pipeline_config.eval_input_path)

    if FLAGS.fine_tune_checkpoint:
      pipeline_config.train_config.fine_tune_checkpoint = FLAGS.fine_tune_checkpoint

    if pipeline_config.train_config.HasField('fine_tune_checkpoint'):
      pipeline_config.train_config.fine_tune_checkpoint = estimator_utils.get_latest_checkpoint_from_checkpoint_path(
          pipeline_config.train_config.fine_tune_checkpoint, False)

    if FLAGS.boundary_table:
      logging.info('Load boundary_table: %s' % FLAGS.boundary_table)
      config_util.add_boundaries_to_config(pipeline_config,
                                           FLAGS.boundary_table)

    if FLAGS.sampler_table:
      pipeline_config.data_config.negative_sampler.input_path = FLAGS.sampler_table

    if FLAGS.train_tables or FLAGS.tables:
      # parse selected_cols
      set_selected_cols(pipeline_config, FLAGS.selected_cols, FLAGS.all_cols,
                        FLAGS.all_col_types)
    else:
      pipeline_config.data_config.selected_cols = ''
      pipeline_config.data_config.selected_col_types = ''

    distribute_strategy = DistributionStrategyMap[FLAGS.distribute_strategy]

    # update params specified by automl if hpo_param_path is specified
    if FLAGS.hpo_param_path:
      logging.info('hpo_param_path = %s' % FLAGS.hpo_param_path)
      with gfile.GFile(FLAGS.hpo_param_path, 'r') as fin:
        hpo_config = yaml.safe_load(fin)
        hpo_params = hpo_config['param']
        config_util.edit_config(pipeline_config, hpo_params)
    config_util.auto_expand_share_feature_configs(pipeline_config)

    print('[run.py] with_evaluator %s' % str(FLAGS.with_evaluator))
    print('[run.py] eval_method %s' % FLAGS.eval_method)
    assert FLAGS.eval_method in [
        'none', 'master', 'separate'
    ], 'invalid eval_method: %s' % FLAGS.eval_method

    # with_evaluator is deprecated, kept for compatibility
    if FLAGS.with_evaluator:
      FLAGS.eval_method = 'separate'
    num_worker = set_tf_config_and_get_train_worker_num(
        FLAGS.ps_hosts,
        FLAGS.worker_hosts,
        FLAGS.task_index,
        FLAGS.job_name,
        distribute_strategy=distribute_strategy,
        eval_method=FLAGS.eval_method)
    set_distribution_config(pipeline_config, num_worker, num_gpus_per_worker,
                            distribute_strategy)
    logging.info('run.py check_mode: %s .' % FLAGS.check_mode)
    train_and_evaluate_impl(
        pipeline_config,
        continue_train=FLAGS.continue_train,
        check_mode=FLAGS.check_mode)

    if FLAGS.hpo_metric_save_path:
      hpo_util.save_eval_metrics(
          pipeline_config.model_dir,
          metric_save_path=FLAGS.hpo_metric_save_path,
          has_evaluator=(FLAGS.eval_method == 'separate'))
  elif FLAGS.cmd == 'evaluate':
    check_param('config')
    # TODO: support multi-worker evaluation
    if not FLAGS.distribute_eval:
      assert len(
          FLAGS.worker_hosts.split(',')) == 1, 'evaluate only need 1 worker'
    config_util.auto_expand_share_feature_configs(pipeline_config)

    if FLAGS.eval_tables:
      pipeline_config.eval_input_path = FLAGS.eval_tables
    elif FLAGS.tables:
      pipeline_config.eval_input_path = FLAGS.tables.split(',')[0]

    distribute_strategy = DistributionStrategyMap[FLAGS.distribute_strategy]
    # the return value is ignored here: num_worker keeps the number of
    # entries in worker_hosts computed above
    set_tf_config_and_get_train_worker_num(
        FLAGS.ps_hosts,
        FLAGS.worker_hosts,
        FLAGS.task_index,
        FLAGS.job_name,
        eval_method='none')
    set_distribution_config(pipeline_config, num_worker, num_gpus_per_worker,
                            distribute_strategy)

    if FLAGS.eval_tables or FLAGS.tables:
      # parse selected_cols
      set_selected_cols(pipeline_config, FLAGS.selected_cols, FLAGS.all_cols,
                        FLAGS.all_col_types)
    else:
      pipeline_config.data_config.selected_cols = ''
      pipeline_config.data_config.selected_col_types = ''

    if FLAGS.distribute_eval:
      os.environ['distribute_eval'] = 'True'
      logging.info('will_use_distribute_eval')
      distribute_eval = os.environ.get('distribute_eval')
      logging.info('distribute_eval = {}'.format(distribute_eval))
      easy_rec.distribute_evaluate(pipeline_config, FLAGS.checkpoint_path, None,
                                   FLAGS.eval_result_path)
    else:
      os.environ['distribute_eval'] = 'False'
      logging.info('will_use_eval')
      distribute_eval = os.environ.get('distribute_eval')
      logging.info('distribute_eval = {}'.format(distribute_eval))
      easy_rec.evaluate(pipeline_config, FLAGS.checkpoint_path, None,
                        FLAGS.eval_result_path)
  elif FLAGS.cmd == 'export':
    check_param('export_dir')
    check_param('config')
    if FLAGS.place_embedding_on_cpu:
      os.environ['place_embedding_on_cpu'] = 'True'
    else:
      os.environ['place_embedding_on_cpu'] = 'False'

    # NOTE(review): the redis_timeout flag is defined but never forwarded
    # here - confirm whether that is intended.
    redis_params = {}
    if FLAGS.redis_url:
      redis_params['redis_url'] = FLAGS.redis_url
    if FLAGS.redis_passwd:
      redis_params['redis_passwd'] = FLAGS.redis_passwd
    if FLAGS.redis_threads > 0:
      redis_params['redis_threads'] = FLAGS.redis_threads
    if FLAGS.redis_batch_size > 0:
      redis_params['redis_batch_size'] = FLAGS.redis_batch_size
    if FLAGS.redis_expire > 0:
      redis_params['redis_expire'] = FLAGS.redis_expire
    if FLAGS.redis_embedding_version:
      redis_params['redis_embedding_version'] = FLAGS.redis_embedding_version
    if FLAGS.redis_write_kv:
      redis_params['redis_write_kv'] = FLAGS.redis_write_kv

    oss_params = {}
    if FLAGS.oss_path:
      oss_params['oss_path'] = FLAGS.oss_path
    if FLAGS.oss_endpoint:
      oss_params['oss_endpoint'] = FLAGS.oss_endpoint
    if FLAGS.oss_ak:
      oss_params['oss_ak'] = FLAGS.oss_ak
    if FLAGS.oss_sk:
      oss_params['oss_sk'] = FLAGS.oss_sk
    if FLAGS.oss_timeout > 0:
      oss_params['oss_timeout'] = FLAGS.oss_timeout
    if FLAGS.oss_expire > 0:
      oss_params['oss_expire'] = FLAGS.oss_expire
    if FLAGS.oss_threads > 0:
      oss_params['oss_threads'] = FLAGS.oss_threads
    if FLAGS.oss_embedding_version:
      # belongs with the other oss options; both dicts are merged below, so
      # the keyword arguments passed to easy_rec.export are unaffected
      oss_params['oss_embedding_version'] = FLAGS.oss_embedding_version
    if FLAGS.oss_write_kv:
      oss_params['oss_write_kv'] = FLAGS.oss_write_kv == 1

    set_tf_config_and_get_train_worker_num(
        FLAGS.ps_hosts,
        FLAGS.worker_hosts,
        FLAGS.task_index,
        FLAGS.job_name,
        eval_method='none')

    assert len(FLAGS.worker_hosts.split(',')) == 1, 'export only need 1 woker'
    config_util.auto_expand_share_feature_configs(pipeline_config)

    export_dir = FLAGS.export_dir
    if not export_dir.endswith('/'):
      export_dir = export_dir + '/'
    if FLAGS.clear_export:
      if gfile.IsDirectory(export_dir):
        gfile.DeleteRecursively(export_dir)

    extra_params = redis_params
    extra_params.update(oss_params)
    export_out_dir = easy_rec.export(export_dir, pipeline_config,
                                     FLAGS.checkpoint_path, FLAGS.asset_files,
                                     FLAGS.verbose, **extra_params)
    if FLAGS.export_done_file:
      flag_file = os.path.join(export_out_dir, FLAGS.export_done_file)
      logging.info('create export done file: %s' % flag_file)
      with gfile.GFile(flag_file, 'w') as fout:
        fout.write('ExportDone')
  elif FLAGS.cmd == 'predict':
    check_param('tables')
    check_param('saved_model_dir')
    logging.info('will use the following columns as model input: %s' %
                 FLAGS.selected_cols)
    logging.info('will copy the following columns to output: %s' %
                 FLAGS.reserved_cols)
    # only task 0 gets a profiling file
    profiling_file = FLAGS.profiling_file if FLAGS.task_index == 0 else None
    if profiling_file is not None:
      print('profiling_file = %s ' % profiling_file)
    predictor = ODPSPredictor(
        FLAGS.saved_model_dir,
        fg_json_path=FLAGS.fg_json_path,
        profiling_file=profiling_file,
        all_cols=FLAGS.all_cols,
        all_col_types=FLAGS.all_col_types)
    input_table, output_table = FLAGS.tables, FLAGS.outputs
    logging.info('input_table = %s, output_table = %s' %
                 (input_table, output_table))
    # the task index and the number of worker hosts are passed as
    # slice_id / slice_num
    worker_num = len(FLAGS.worker_hosts.split(','))
    predictor.predict_impl(
        input_table,
        output_table,
        reserved_cols=FLAGS.reserved_cols,
        output_cols=FLAGS.output_cols,
        batch_size=FLAGS.batch_size,
        slice_id=FLAGS.task_index,
        slice_num=worker_num)
  elif FLAGS.cmd == 'export_checkpoint':
    check_param('export_dir')
    check_param('config')
    set_tf_config_and_get_train_worker_num(
        FLAGS.ps_hosts,
        FLAGS.worker_hosts,
        FLAGS.task_index,
        FLAGS.job_name,
        eval_method='none')
    assert len(FLAGS.worker_hosts.split(',')) == 1, 'export only need 1 woker'
    config_util.auto_expand_share_feature_configs(pipeline_config)
    easy_rec.export_checkpoint(
        pipeline_config,
        export_path=FLAGS.export_dir + '/model',
        checkpoint_path=FLAGS.checkpoint_path,
        asset_files=FLAGS.asset_files,
        verbose=FLAGS.verbose)
  elif FLAGS.cmd == 'vector_retrieve':
    check_param('knn_distance')
    assert FLAGS.knn_feature_dims is not None, '`knn_feature_dims` should not be None'
    assert FLAGS.knn_num_neighbours is not None, '`knn_num_neighbours` should not be None'

    query_table, doc_table, output_table = FLAGS.query_table, FLAGS.doc_table, FLAGS.outputs
    if not query_table:
      # drop empty entries: ''.split(',') yields [''], which would otherwise
      # satisfy the length check and pass an empty table name on
      tables = [table for table in FLAGS.tables.split(',') if table]
      assert len(
          tables
      ) >= 1, 'at least 1 tables must be specified, but only[%d]: %s' % (
          len(tables), FLAGS.tables)
      query_table = tables[0]
      # without a second table the query table is also used as doc table
      doc_table = tables[1] if len(tables) > 1 else query_table

    knn = VectorRetrieve(
        query_table,
        doc_table,
        output_table,
        ndim=FLAGS.knn_feature_dims,
        distance=1 if FLAGS.knn_distance == 'inner_product' else 0,
        delimiter=FLAGS.knn_feature_delimiter,
        batch_size=FLAGS.batch_size,
        index_type=FLAGS.knn_index_type,
        nlist=FLAGS.knn_nlist,
        nprobe=FLAGS.knn_nprobe,
        m=FLAGS.knn_compress_dim)
    worker_hosts = FLAGS.worker_hosts.split(',')
    knn(FLAGS.knn_num_neighbours, FLAGS.task_index, len(worker_hosts))
  elif FLAGS.cmd == 'check':
    # pipeline_config only exists when a config file was given
    check_param('config')
    run_check(pipeline_config, FLAGS.tables)
  else:
    raise ValueError(
        'cmd should be one of train/evaluate/export/predict/export_checkpoint/vector_retrieve/check, but got: %s'
        % FLAGS.cmd)
# Script entry point: tf.app.run() parses the command line flags defined in
# this file and then calls main().
if __name__ == '__main__':
  tf.app.run()