# def main()
#
# in pai_jobs/run.py [0:0]


def main(argv):
  """Run one easy_rec job on PAI, selected by `FLAGS.cmd`.

  The function first applies process-wide settings (`pai_util.set_on_pai()`,
  optional avx string split, the `distribute_eval` environment variable and
  the embedding lookup op), then loads the pipeline config from
  `FLAGS.config` and patches it with the command line flags, and finally runs
  the branch selected by `FLAGS.cmd`: train, evaluate, export, predict,
  export_checkpoint, vector_retrieve or check.

  Args:
    argv: positional command line arguments; not used by this function.

  Raises:
    ValueError: if `FLAGS.cmd` is not a supported command, or if a flag that
      needs the pipeline config is given while `FLAGS.config` is empty.
    AssertionError: if the flag combination is invalid for the selected
      command, e.g. an unknown distribute_strategy, a model_dir that does not
      start with 'oss://', too few tables, or more than one worker for a
      single-worker command.
  """
  pai_util.set_on_pai()
  if FLAGS.enable_avx_str_split:
    constant.enable_avx_str_split()
    logging.info('will enable avx str split: %s' %
                 constant.is_avx_str_split_enabled())

  if FLAGS.distribute_eval:
    os.environ['distribute_eval'] = 'True'

  # load lookup op
  # best effort: a failure is only printed and the job continues without it
  try:
    lookup_op_path = os.path.join(easy_rec.ops_dir, 'libembed_op.so')
    tf.load_op_library(lookup_op_path)
  except Exception as ex:
    print('Error: exception: %s' % str(ex))

  num_gpus_per_worker = FLAGS.num_gpus_per_worker
  worker_hosts = FLAGS.worker_hosts.split(',')
  num_worker = len(worker_hosts)
  assert FLAGS.distribute_strategy in DistributionStrategyMap, \
      'invalid distribute_strategy [%s], available ones are %s' % (
          FLAGS.distribute_strategy, ','.join(DistributionStrategyMap.keys()))

  # stays None when no config is given (predict / vector_retrieve do not
  # read it), so that a missing config is detected explicitly below instead
  # of surfacing as an UnboundLocalError
  pipeline_config = None
  if FLAGS.config:
    config = pai_util.process_config(FLAGS.config, FLAGS.task_index,
                                     num_worker)
    pipeline_config = config_util.get_configs_from_pipeline_file(config, False)

    # should be in front of edit_config_json step
    # otherwise data_config and feature_config are not ready
    if pipeline_config.fg_json_path:
      fg_util.load_fg_json_to_config(pipeline_config)

  if pipeline_config is None:
    # every step between here and the command dispatch reads or edits
    # pipeline_config; fail early with the name of the offending flag
    config_dependent_flags = [
        name for name, value in (
            ('edit_config_json', FLAGS.edit_config_json),
            ('model_dir', FLAGS.model_dir),
            ('asset_files', FLAGS.asset_files),
            ('clear_model', FLAGS.clear_model),
        ) if value
    ]
    if FLAGS.max_wait_ckpt_ts > 0 and not FLAGS.checkpoint_path:
      # without checkpoint_path the wait falls back to the config model_dir
      config_dependent_flags.append('max_wait_ckpt_ts')
    if config_dependent_flags:
      raise ValueError('config should not be empty when [%s] is specified' %
                       ','.join(config_dependent_flags))

  if FLAGS.edit_config_json:
    print('[run.py] edit_config_json = %s' % FLAGS.edit_config_json)
    config_json = yaml.safe_load(FLAGS.edit_config_json)
    config_util.edit_config(pipeline_config, config_json)

  if FLAGS.model_dir:
    pipeline_config.model_dir = FLAGS.model_dir
    pipeline_config.model_dir = pipeline_config.model_dir.strip()
    print('[run.py] update model_dir to %s' % pipeline_config.model_dir)
    assert pipeline_config.model_dir.startswith(
        'oss://'), 'invalid model_dir format: %s' % pipeline_config.model_dir

  if FLAGS.asset_files:
    pipeline_config.export_config.asset_files.extend(
        FLAGS.asset_files.split(','))

  if FLAGS.config:
    # normalize model_dir so that it always ends with '/'
    if not pipeline_config.model_dir.endswith('/'):
      pipeline_config.model_dir += '/'

  if FLAGS.clear_model:
    # only the chief removes an existing model directory
    if gfile.IsDirectory(
        pipeline_config.model_dir) and estimator_utils.is_chief():
      gfile.DeleteRecursively(pipeline_config.model_dir)

  if FLAGS.max_wait_ckpt_ts > 0:
    if FLAGS.checkpoint_path:
      _wait_ckpt(FLAGS.checkpoint_path, FLAGS.max_wait_ckpt_ts)
    else:
      _wait_ckpt(pipeline_config.model_dir, FLAGS.max_wait_ckpt_ts)

  if FLAGS.cmd == 'train':
    assert FLAGS.config, 'config should not be empty when training!'

    # `tables` is read as "train_table,eval_table" whenever train_tables or
    # eval_tables is missing; index 1 is needed for eval, so two entries are
    # required in both cases
    if FLAGS.tables and not (FLAGS.train_tables and FLAGS.eval_tables):
      tables = FLAGS.tables.split(',')
      assert len(
          tables
      ) >= 2, 'at least 2 tables must be specified, but only[%d]: %s' % (
          len(tables), FLAGS.tables)

    if FLAGS.train_tables:
      pipeline_config.train_input_path = FLAGS.train_tables
    elif FLAGS.tables:
      pipeline_config.train_input_path = FLAGS.tables.split(',')[0]

    if FLAGS.eval_tables:
      pipeline_config.eval_input_path = FLAGS.eval_tables
    elif FLAGS.tables:
      pipeline_config.eval_input_path = FLAGS.tables.split(',')[1]

    print('[run.py] train_tables: %s' % pipeline_config.train_input_path)
    print('[run.py] eval_tables: %s' % pipeline_config.eval_input_path)

    if FLAGS.fine_tune_checkpoint:
      pipeline_config.train_config.fine_tune_checkpoint = FLAGS.fine_tune_checkpoint

    if pipeline_config.train_config.HasField('fine_tune_checkpoint'):
      pipeline_config.train_config.fine_tune_checkpoint = estimator_utils.get_latest_checkpoint_from_checkpoint_path(
          pipeline_config.train_config.fine_tune_checkpoint, False)

    if FLAGS.boundary_table:
      logging.info('Load boundary_table: %s' % FLAGS.boundary_table)
      config_util.add_boundaries_to_config(pipeline_config,
                                           FLAGS.boundary_table)

    if FLAGS.sampler_table:
      pipeline_config.data_config.negative_sampler.input_path = FLAGS.sampler_table

    if FLAGS.train_tables or FLAGS.tables:
      # parse selected_cols
      set_selected_cols(pipeline_config, FLAGS.selected_cols, FLAGS.all_cols,
                        FLAGS.all_col_types)
    else:
      pipeline_config.data_config.selected_cols = ''
      pipeline_config.data_config.selected_col_types = ''

    distribute_strategy = DistributionStrategyMap[FLAGS.distribute_strategy]

    # update params specified by automl if hpo_param_path is specified
    if FLAGS.hpo_param_path:
      logging.info('hpo_param_path = %s' % FLAGS.hpo_param_path)
      with gfile.GFile(FLAGS.hpo_param_path, 'r') as fin:
        hpo_config = yaml.safe_load(fin)
        hpo_params = hpo_config['param']
        config_util.edit_config(pipeline_config, hpo_params)
    config_util.auto_expand_share_feature_configs(pipeline_config)

    print('[run.py] with_evaluator %s' % str(FLAGS.with_evaluator))
    print('[run.py] eval_method %s' % FLAGS.eval_method)
    assert FLAGS.eval_method in [
        'none', 'master', 'separate'
    ], 'invalid evalaute_method: %s' % FLAGS.eval_method

    # with_evaluator is deprecated, kept for compatibility
    if FLAGS.with_evaluator:
      FLAGS.eval_method = 'separate'

    # the train worker number returned here replaces the raw host count
    num_worker = set_tf_config_and_get_train_worker_num(
        FLAGS.ps_hosts,
        FLAGS.worker_hosts,
        FLAGS.task_index,
        FLAGS.job_name,
        distribute_strategy=distribute_strategy,
        eval_method=FLAGS.eval_method)
    set_distribution_config(pipeline_config, num_worker, num_gpus_per_worker,
                            distribute_strategy)
    logging.info('run.py check_mode: %s .' % FLAGS.check_mode)
    train_and_evaluate_impl(
        pipeline_config,
        continue_train=FLAGS.continue_train,
        check_mode=FLAGS.check_mode)

    if FLAGS.hpo_metric_save_path:
      hpo_util.save_eval_metrics(
          pipeline_config.model_dir,
          metric_save_path=FLAGS.hpo_metric_save_path,
          has_evaluator=(FLAGS.eval_method == 'separate'))

  elif FLAGS.cmd == 'evaluate':
    check_param('config')
    # TODO: support multi-worker evaluation
    if not FLAGS.distribute_eval:
      assert len(worker_hosts) == 1, 'evaluate only need 1 worker'
    config_util.auto_expand_share_feature_configs(pipeline_config)

    if FLAGS.eval_tables:
      pipeline_config.eval_input_path = FLAGS.eval_tables
    elif FLAGS.tables:
      pipeline_config.eval_input_path = FLAGS.tables.split(',')[0]

    distribute_strategy = DistributionStrategyMap[FLAGS.distribute_strategy]
    set_tf_config_and_get_train_worker_num(
        FLAGS.ps_hosts,
        FLAGS.worker_hosts,
        FLAGS.task_index,
        FLAGS.job_name,
        eval_method='none')
    # num_worker is the number of entries in FLAGS.worker_hosts here
    set_distribution_config(pipeline_config, num_worker, num_gpus_per_worker,
                            distribute_strategy)

    if FLAGS.eval_tables or FLAGS.tables:
      # parse selected_cols
      set_selected_cols(pipeline_config, FLAGS.selected_cols, FLAGS.all_cols,
                        FLAGS.all_col_types)
    else:
      pipeline_config.data_config.selected_cols = ''
      pipeline_config.data_config.selected_col_types = ''

    if FLAGS.distribute_eval:
      os.environ['distribute_eval'] = 'True'
      logging.info('will_use_distribute_eval')
      distribute_eval = os.environ.get('distribute_eval')
      logging.info('distribute_eval = {}'.format(distribute_eval))
      easy_rec.distribute_evaluate(pipeline_config, FLAGS.checkpoint_path, None,
                                   FLAGS.eval_result_path)
    else:
      os.environ['distribute_eval'] = 'False'
      logging.info('will_use_eval')
      distribute_eval = os.environ.get('distribute_eval')
      logging.info('distribute_eval = {}'.format(distribute_eval))
      easy_rec.evaluate(pipeline_config, FLAGS.checkpoint_path, None,
                        FLAGS.eval_result_path)
  elif FLAGS.cmd == 'export':
    check_param('export_dir')
    check_param('config')
    if FLAGS.place_embedding_on_cpu:
      os.environ['place_embedding_on_cpu'] = 'True'
    else:
      os.environ['place_embedding_on_cpu'] = 'False'

    # only flags that are actually set are forwarded to easy_rec.export
    redis_params = {}
    if FLAGS.redis_url:
      redis_params['redis_url'] = FLAGS.redis_url
    if FLAGS.redis_passwd:
      redis_params['redis_passwd'] = FLAGS.redis_passwd
    if FLAGS.redis_threads > 0:
      redis_params['redis_threads'] = FLAGS.redis_threads
    if FLAGS.redis_batch_size > 0:
      redis_params['redis_batch_size'] = FLAGS.redis_batch_size
    if FLAGS.redis_expire > 0:
      redis_params['redis_expire'] = FLAGS.redis_expire
    if FLAGS.redis_embedding_version:
      redis_params['redis_embedding_version'] = FLAGS.redis_embedding_version
    if FLAGS.redis_write_kv:
      redis_params['redis_write_kv'] = FLAGS.redis_write_kv

    oss_params = {}
    if FLAGS.oss_path:
      oss_params['oss_path'] = FLAGS.oss_path
    if FLAGS.oss_endpoint:
      oss_params['oss_endpoint'] = FLAGS.oss_endpoint
    if FLAGS.oss_ak:
      oss_params['oss_ak'] = FLAGS.oss_ak
    if FLAGS.oss_sk:
      oss_params['oss_sk'] = FLAGS.oss_sk
    if FLAGS.oss_timeout > 0:
      oss_params['oss_timeout'] = FLAGS.oss_timeout
    if FLAGS.oss_expire > 0:
      oss_params['oss_expire'] = FLAGS.oss_expire
    if FLAGS.oss_threads > 0:
      oss_params['oss_threads'] = FLAGS.oss_threads
    if FLAGS.oss_embedding_version:
      # belongs with the other oss_* keys; the merged kwargs are unchanged
      oss_params['oss_embedding_version'] = FLAGS.oss_embedding_version
    if FLAGS.oss_write_kv:
      oss_params['oss_write_kv'] = FLAGS.oss_write_kv == 1

    set_tf_config_and_get_train_worker_num(
        FLAGS.ps_hosts,
        FLAGS.worker_hosts,
        FLAGS.task_index,
        FLAGS.job_name,
        eval_method='none')

    assert len(worker_hosts) == 1, 'export only need 1 woker'
    config_util.auto_expand_share_feature_configs(pipeline_config)

    export_dir = FLAGS.export_dir
    if not export_dir.endswith('/'):
      export_dir = export_dir + '/'
    if FLAGS.clear_export:
      if gfile.IsDirectory(export_dir):
        gfile.DeleteRecursively(export_dir)

    # redis and oss settings are passed together as keyword arguments
    extra_params = dict(redis_params)
    extra_params.update(oss_params)
    export_out_dir = easy_rec.export(export_dir, pipeline_config,
                                     FLAGS.checkpoint_path, FLAGS.asset_files,
                                     FLAGS.verbose, **extra_params)
    if FLAGS.export_done_file:
      flag_file = os.path.join(export_out_dir, FLAGS.export_done_file)
      logging.info('create export done file: %s' % flag_file)
      with gfile.GFile(flag_file, 'w') as fout:
        fout.write('ExportDone')
  elif FLAGS.cmd == 'predict':
    check_param('tables')
    check_param('saved_model_dir')
    logging.info('will use the following columns as model input: %s' %
                 FLAGS.selected_cols)
    logging.info('will copy the following columns to output: %s' %
                 FLAGS.reserved_cols)

    # profiling is only enabled on the worker with task_index 0
    profiling_file = FLAGS.profiling_file if FLAGS.task_index == 0 else None
    if profiling_file is not None:
      print('profiling_file = %s ' % profiling_file)
    predictor = ODPSPredictor(
        FLAGS.saved_model_dir,
        fg_json_path=FLAGS.fg_json_path,
        profiling_file=profiling_file,
        all_cols=FLAGS.all_cols,
        all_col_types=FLAGS.all_col_types)
    input_table, output_table = FLAGS.tables, FLAGS.outputs
    logging.info('input_table = %s, output_table = %s' %
                 (input_table, output_table))
    # task_index / worker count are passed as slice_id / slice_num
    predictor.predict_impl(
        input_table,
        output_table,
        reserved_cols=FLAGS.reserved_cols,
        output_cols=FLAGS.output_cols,
        batch_size=FLAGS.batch_size,
        slice_id=FLAGS.task_index,
        slice_num=len(worker_hosts))
  elif FLAGS.cmd == 'export_checkpoint':
    check_param('export_dir')
    check_param('config')
    set_tf_config_and_get_train_worker_num(
        FLAGS.ps_hosts,
        FLAGS.worker_hosts,
        FLAGS.task_index,
        FLAGS.job_name,
        eval_method='none')
    assert len(worker_hosts) == 1, 'export only need 1 woker'
    config_util.auto_expand_share_feature_configs(pipeline_config)
    easy_rec.export_checkpoint(
        pipeline_config,
        export_path=FLAGS.export_dir + '/model',
        checkpoint_path=FLAGS.checkpoint_path,
        asset_files=FLAGS.asset_files,
        verbose=FLAGS.verbose)
  elif FLAGS.cmd == 'vector_retrieve':
    check_param('knn_distance')
    assert FLAGS.knn_feature_dims is not None, '`knn_feature_dims` should not be None'
    assert FLAGS.knn_num_neighbours is not None, '`knn_num_neighbours` should not be None'

    query_table, doc_table, output_table = FLAGS.query_table, FLAGS.doc_table, FLAGS.outputs
    if not query_table:
      # fall back to `tables`: "query_table[,doc_table]"; checking the length
      # of the split result would always pass, so check the flag itself
      assert FLAGS.tables, (
          'at least 1 tables must be specified when query_table is empty, '
          'but got tables=%s' % FLAGS.tables)
      tables = FLAGS.tables.split(',')
      query_table = tables[0]
      doc_table = tables[1] if len(tables) > 1 else query_table

    knn = VectorRetrieve(
        query_table,
        doc_table,
        output_table,
        ndim=FLAGS.knn_feature_dims,
        distance=1 if FLAGS.knn_distance == 'inner_product' else 0,
        delimiter=FLAGS.knn_feature_delimiter,
        batch_size=FLAGS.batch_size,
        index_type=FLAGS.knn_index_type,
        nlist=FLAGS.knn_nlist,
        nprobe=FLAGS.knn_nprobe,
        m=FLAGS.knn_compress_dim)
    knn(FLAGS.knn_num_neighbours, FLAGS.task_index, len(worker_hosts))
  elif FLAGS.cmd == 'check':
    assert FLAGS.config, 'config should not be empty when checking!'
    run_check(pipeline_config, FLAGS.tables)
  else:
    raise ValueError(
        'cmd should be one of train/evaluate/export/predict/export_checkpoint/vector_retrieve/check'
    )