in easy_rec/python/main.py
def _train_and_evaluate_impl(pipeline_config,
continue_train=False,
check_mode=False,
fit_on_eval=False,
fit_on_eval_steps=None):
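  """Train a model and evaluate it, as configured by pipeline_config.

  Args:
    pipeline_config: EasyRecConfig proto describing data, features, model
      and training.
    continue_train: whether to resume from an existing model_dir instead of
      requiring a fresh one.
    check_mode: run the input pipeline in checking mode (passed through to
      the input_fn).
    fit_on_eval: after normal training finishes, continue training on the
      eval data (skipped on the evaluator).
    fit_on_eval_steps: extra steps to train on the eval data; None trains
      until the eval data is exhausted.

  Returns:
    the tf.estimator.Estimator instance, for further custom training.
  """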
train_config = pipeline_config.train_config
data_config = pipeline_config.data_config
feature_configs = config_util.get_compatible_feature_configs(pipeline_config)
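  # sync_replicas is incompatible with tf.distribute strategies, so it is
  # forced off whenever train_distribute is anything but NoStrategy.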
if train_config.train_distribute != DistributionStrategy.NoStrategy\
and train_config.sync_replicas:
    logging.warning(
        'will set sync_replicas to False because train_distribute[%s] != NoStrategy'
        % pipeline_config.train_config.train_distribute)
pipeline_config.train_config.sync_replicas = False
train_data = get_train_input_path(pipeline_config)
eval_data = get_eval_input_path(pipeline_config)
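  # build the tf.distribute strategy selected by train_config.train_distribute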
distribution = strategy_builder.build(train_config)
params = {}
if train_config.is_profiling:
params['log_device_placement'] = True
estimator, run_config = _create_estimator(
pipeline_config, distribution=distribution, params=params)
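  # only the chief worker prepares model_dir: it checks the directory
  # (honoring continue_train), saves the effective pipeline config and
  # records the EasyRec version used for training.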
version_file = os.path.join(pipeline_config.model_dir, 'version')
if estimator_utils.is_chief():
_check_model_dir(pipeline_config.model_dir, continue_train)
config_util.save_pipeline_config(pipeline_config, pipeline_config.model_dir)
with gfile.GFile(version_file, 'w') as f:
f.write(easy_rec.__version__ + '\n')
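  # resolve the stopping condition: an explicit step budget, an epoch
  # budget, or both (whichever limit is reached first wins).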
train_steps = None
if train_config.HasField('num_steps') and train_config.num_steps > 0:
train_steps = train_config.num_steps
  assert train_steps is not None or data_config.num_epochs > 0, (
      'either num_steps or num_epochs must be set to an integer > 0.')
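  # when both limits are set, log an estimate of the effective step count;
  # under synchronous training each global step consumes one batch per
  # worker, hence the division by worker_num.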
if train_steps and data_config.num_epochs:
logging.info('Both num_steps and num_epochs are set.')
is_sync = train_config.sync_replicas
batch_size = data_config.batch_size
epoch_str = 'sample_num * %d / %d' % (data_config.num_epochs, batch_size)
if is_sync:
_, worker_num = estimator_utils.get_task_index_and_num()
epoch_str += ' / ' + str(worker_num)
logging.info('Will train min(%d, %s) steps...' % (train_steps, epoch_str))
input_fn_kwargs = {'pipeline_config': pipeline_config}
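  # OdpsRTPInputV2 parses raw inputs with an fg (feature generation) json
  # config, so its path has to reach the input_fn.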
if data_config.input_type == data_config.InputType.OdpsRTPInputV2:
input_fn_kwargs['fg_json_path'] = pipeline_config.fg_json_path
# create train input
train_input_fn = _get_input_fn(
data_config,
feature_configs,
train_data,
check_mode=check_mode,
**input_fn_kwargs)
  # create train spec
train_spec = tf.estimator.TrainSpec(
input_fn=train_input_fn, max_steps=train_steps)
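  # SOK / embedding-parallel strategies drive estimator.train directly and
  # skip evaluation; all other strategies go through train_and_evaluate.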
embedding_parallel = train_config.train_distribute in (
DistributionStrategy.SokStrategy,
DistributionStrategy.EmbeddingParallelStrategy)
if embedding_parallel:
estimator.train(
input_fn=train_input_fn,
max_steps=train_spec.max_steps,
hooks=list(train_spec.hooks),
saving_listeners=train_spec.saving_listeners)
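    # stop the input creator once training finishes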
train_input_fn.input_creator.stop()
else:
    # create eval spec; currently only a single eval spec is allowed
eval_spec = _create_eval_export_spec(
pipeline_config, eval_data, check_mode=check_mode)
estimator_train.train_and_evaluate(estimator, train_spec, eval_spec)
  logging.info('Train and evaluate finished')
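  # optionally fine-tune on the eval data; the evaluator takes no part in
  # training, so it skips this step.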
if fit_on_eval and (not estimator_utils.is_evaluator()):
tf.reset_default_graph()
    logging.info('Start training on eval data')
eval_input_fn = _get_input_fn(data_config, feature_configs, eval_data,
**input_fn_kwargs)
if fit_on_eval_steps is not None:
      # wait until training has finished so the recorded train_steps is correct
while not estimator_train.estimator_train_done(estimator):
time.sleep(1)
train_steps = estimator_utils.get_trained_steps(estimator.model_dir)
logging.info('\ttrain_steps=%d fit_on_eval_steps=%d' %
(train_steps, fit_on_eval_steps))
fit_on_eval_steps += train_steps
    # Do not use estimator_train.train_and_evaluate here: it would start a
    # redundant tf.train.Server and fail with a port-not-available error.
estimator.train(
input_fn=eval_input_fn,
max_steps=fit_on_eval_steps,
hooks=list(train_spec.hooks),
saving_listeners=train_spec.saving_listeners if hasattr(
train_spec, 'saving_listeners') else None)
logging.info('Finished training on eval data')
# return estimator for custom training using estimator.train
return estimator
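A minimal driver sketch. It assumes config_util.get_configs_from_pipeline_file, EasyRec's loader for pipeline config files, and an illustrative config path:

# hypothetical driver; the config path is illustrative, and
# get_configs_from_pipeline_file is assumed to return an EasyRecConfig proto
from easy_rec.python.utils import config_util

pipeline_config = config_util.get_configs_from_pipeline_file(
    'samples/model_config/deepfm.config')
estimator = _train_and_evaluate_impl(
    pipeline_config,
    continue_train=True)  # resume if model_dir already holds checkpoints
# the returned estimator can be reused for further estimator.train calls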