def _train_and_evaluate_impl()

in easy_rec/python/main.py [0:0]


def _train_and_evaluate_impl(pipeline_config,
                             continue_train=False,
                             check_mode=False,
                             fit_on_eval=False,
                             fit_on_eval_steps=None):
  train_config = pipeline_config.train_config
  data_config = pipeline_config.data_config
  feature_configs = config_util.get_compatible_feature_configs(pipeline_config)

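  # sync_replicas is only compatible with NoStrategy; disable it for other
  # distribution strategies and warn.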
  if train_config.train_distribute != DistributionStrategy.NoStrategy\
      and train_config.sync_replicas:
    logging.warning(
        'will set sync_replicas to False, because train_distribute[%s] != NoStrategy'
        % pipeline_config.train_config.train_distribute)
    pipeline_config.train_config.sync_replicas = False

  train_data = get_train_input_path(pipeline_config)
  eval_data = get_eval_input_path(pipeline_config)

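  # build the distribution strategy and create the estimator; log device
  # placement when profiling is enabled.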
  distribution = strategy_builder.build(train_config)
  params = {}
  if train_config.is_profiling:
    params['log_device_placement'] = True
  estimator, run_config = _create_estimator(
      pipeline_config, distribution=distribution, params=params)

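  # only the chief worker touches model_dir: validate it, save the pipeline
  # config and record the EasyRec version.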
  version_file = os.path.join(pipeline_config.model_dir, 'version')
  if estimator_utils.is_chief():
    _check_model_dir(pipeline_config.model_dir, continue_train)
    config_util.save_pipeline_config(pipeline_config, pipeline_config.model_dir)
    with gfile.GFile(version_file, 'w') as f:
      f.write(easy_rec.__version__ + '\n')

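  # resolve the training step budget: use train_config.num_steps if set,
  # otherwise data_config.num_epochs must bound training.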
  train_steps = None
  if train_config.HasField('num_steps') and train_config.num_steps > 0:
    train_steps = train_config.num_steps
  assert train_steps is not None or data_config.num_epochs > 0, (
      'either num_steps or num_epochs must be set to an integer > 0.')

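  # when both limits are set, training stops at whichever is reached first;
  # under sync_replicas the per-epoch steps are further divided among workers.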
  if train_steps and data_config.num_epochs:
    logging.info('Both num_steps and num_epochs are set.')
    is_sync = train_config.sync_replicas
    batch_size = data_config.batch_size
    epoch_str = 'sample_num * %d / %d' % (data_config.num_epochs, batch_size)
    if is_sync:
      _, worker_num = estimator_utils.get_task_index_and_num()
      epoch_str += ' / ' + str(worker_num)
    logging.info('Will train min(%d, %s) steps...' % (train_steps, epoch_str))

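  # extra kwargs passed to the input_fn; OdpsRTPInputV2 additionally requires
  # the fg json (feature generation) config.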
  input_fn_kwargs = {'pipeline_config': pipeline_config}
  if data_config.input_type == data_config.InputType.OdpsRTPInputV2:
    input_fn_kwargs['fg_json_path'] = pipeline_config.fg_json_path

  # create train input
  train_input_fn = _get_input_fn(
      data_config,
      feature_configs,
      train_data,
      check_mode=check_mode,
      **input_fn_kwargs)
  # create train spec; max_steps=None means train until the input is exhausted
  train_spec = tf.estimator.TrainSpec(
      input_fn=train_input_fn, max_steps=train_steps)

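  # SOK / embedding-parallel strategies run estimator.train directly; other
  # strategies go through train_and_evaluate below.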
  embedding_parallel = train_config.train_distribute in (
      DistributionStrategy.SokStrategy,
      DistributionStrategy.EmbeddingParallelStrategy)

  if embedding_parallel:
    estimator.train(
        input_fn=train_input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks),
        saving_listeners=train_spec.saving_listeners)
    train_input_fn.input_creator.stop()
  else:
    # create eval spec; currently only a single eval spec is allowed
    eval_spec = _create_eval_export_spec(
        pipeline_config, eval_data, check_mode=check_mode)
    estimator_train.train_and_evaluate(estimator, train_spec, eval_spec)
  logging.info('Train and evaluate finished')
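  # optionally continue training on the eval data (skipped on the evaluator
  # task, which does not train).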
  if fit_on_eval and (not estimator_utils.is_evaluator()):
    tf.reset_default_graph()
    logging.info('Continue training on eval data')
    eval_input_fn = _get_input_fn(data_config, feature_configs, eval_data,
                                  **input_fn_kwargs)
    if fit_on_eval_steps is not None:
      # wait estimator train done to get the correct train_steps
      while not estimator_train.estimator_train_done(estimator):
        time.sleep(1)
      train_steps = estimator_utils.get_trained_steps(estimator.model_dir)
      logging.info('\ttrain_steps=%d fit_on_eval_steps=%d' %
                   (train_steps, fit_on_eval_steps))
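      # estimator.train treats max_steps as an absolute global step count, so
      # offset fit_on_eval_steps by the steps already trained.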
      fit_on_eval_steps += train_steps
    # Do not use estimator_train.train_and_evaluate here: it starts a tf.Server,
    # which is redundant and fails with a port-not-available error.
    estimator.train(
        input_fn=eval_input_fn,
        max_steps=fit_on_eval_steps,
        hooks=list(train_spec.hooks),
        saving_listeners=train_spec.saving_listeners if hasattr(
            train_spec, 'saving_listeners') else None)
    logging.info('Finished training on eval data')
  # return the estimator so callers can run custom training via estimator.train
  return estimator
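
A minimal usage sketch (not part of main.py): load the pipeline config and call
the function directly. The config path is a placeholder, the config_util loader
name follows EasyRec's config_util but should be treated as an assumption, and
calling the private _train_and_evaluate_impl instead of the public train
wrappers is only for illustration.

from easy_rec.python.utils import config_util
from easy_rec.python.main import _train_and_evaluate_impl

# placeholder path; point this at a real EasyRec pipeline config
pipeline_config = config_util.get_configs_from_pipeline_file(
    'path/to/pipeline.config')
estimator = _train_and_evaluate_impl(pipeline_config, continue_train=False)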