in easy_rec/python/model/easy_rec_estimator.py [0:0]
def _train_model_fn(self, features, labels, run_config):
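  """Build the training graph and return an EstimatorSpec for TRAIN mode.

  Args:
    features: input feature tensors produced by the input pipeline.
    labels: label tensors used to build the loss graph.
    run_config: run configuration of the current training job.

  Returns:
    tf.estimator.EstimatorSpec with mode=tf.estimator.ModeKeys.TRAIN.
  """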
  tf.keras.backend.set_learning_phase(1)
  model = self._model_cls(
      self.model_config,
      self.feature_configs,
      features,
      labels,
      is_training=True)
  predict_dict = model.build_predict_graph()
  loss_dict = model.build_loss_graph()
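  # sum regularization losses registered on the graph, if any, and track
  # them as a separate entry in loss_dict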
  regularization_losses = tf.get_collection(
      tf.GraphKeys.REGULARIZATION_LOSSES)
  if regularization_losses:
    regularization_losses = [
        reg_loss.get() if hasattr(reg_loss, 'get') else reg_loss
        for reg_loss in regularization_losses
    ]
    regularization_losses = tf.add_n(
        regularization_losses, name='regularization_loss')
    loss_dict['regularization_loss'] = regularization_losses
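  # losses produced by variational dropout are collected from the graph
  # and added as another loss term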
  variational_dropout_loss = tf.get_collection('variational_dropout_loss')
  if variational_dropout_loss:
    variational_dropout_loss = tf.add_n(
        variational_dropout_loss, name='variational_dropout_loss')
    loss_dict['variational_dropout_loss'] = variational_dropout_loss
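  # the total loss is the sum of all loss terms; every term is also
  # exported as a scalar summary under the 'loss' family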
  loss = tf.add_n(list(loss_dict.values()))
  loss_dict['total_loss'] = loss
  for key in loss_dict:
    tf.summary.scalar(key, loss_dict[key], family='loss')
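  # each worker writes its current input offset into a shared string
  # variable, which is saved later by the checkpoint saver hook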
  if Input.DATA_OFFSET in features:
    task_index, task_num = estimator_utils.get_task_index_and_num()
    data_offset_var = tf.get_variable(
        name=Input.DATA_OFFSET,
        dtype=tf.string,
        shape=[task_num],
        collections=[tf.GraphKeys.GLOBAL_VARIABLES, Input.DATA_OFFSET],
        trainable=False)
    update_offset = tf.assign(data_offset_var[task_index],
                              features[Input.DATA_OFFSET])
    ops.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_offset)
  else:
    data_offset_var = None
  # update ops, usually used for batch-norm
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  if update_ops:
    # register variables touched by update ops, such as batch-norm
    # moving_mean and moving_variance, for incremental update
    global_vars = {x.name: x for x in tf.global_variables()}
    for x in update_ops:
      if isinstance(x, ops.Operation) and x.inputs[0].name in global_vars:
        ops.add_to_collection(constant.DENSE_UPDATE_VARIABLES,
                              global_vars[x.inputs[0].name])
    update_op = tf.group(*update_ops, name='update_barrier')
    with tf.control_dependencies([update_op]):
      loss = tf.identity(loss, name='total_loss')
  # build optimizer
  if len(self.train_config.optimizer_config) == 1:
    optimizer_config = self.train_config.optimizer_config[0]
    optimizer, learning_rate = optimizer_builder.build(optimizer_config)
    tf.summary.scalar('learning_rate', learning_rate[0])
  else:
    optimizer_config = self.train_config.optimizer_config
    all_opts = []
    for opti_id, tmp_config in enumerate(optimizer_config):
      with tf.name_scope('optimizer_%d' % opti_id):
        opt, learning_rate = optimizer_builder.build(tmp_config)
        tf.summary.scalar('learning_rate', learning_rate[0])
      all_opts.append(opt)
    grouped_vars = model.get_grouped_vars(len(all_opts))
    assert len(grouped_vars) == len(optimizer_config), \
        'the number of var group(%d) != the number of optimizers(%d)' \
        % (len(grouped_vars), len(optimizer_config))
    optimizer = MultiOptimizer(all_opts, grouped_vars)
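  # when SokStrategy is enabled, wrap the optimizer with the SOK optimizer
  # wrapper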
  if self.train_config.train_distribute == DistributionStrategy.SokStrategy:
    optimizer = sok_optimizer.OptimizerWrapper(optimizer)
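  # session run hooks; with horovod, broadcast initial variables from rank 0
  # instead of using sync_replicas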
  hooks = []
  if estimator_utils.has_hvd():
    assert not self.train_config.sync_replicas, \
        'sync_replicas should not be set when using horovod'
    bcast_hook = hvd_utils.BroadcastGlobalVariablesHook(0)
    hooks.append(bcast_hook)
  # for distributed and synced training
  if self.train_config.sync_replicas and run_config.num_worker_replicas > 1:
    logging.info('sync_replicas: num_worker_replicas = %d' %
                 run_config.num_worker_replicas)
    if pai_util.is_on_pai():
      optimizer = tf.train.SyncReplicasOptimizer(
          optimizer,
          replicas_to_aggregate=run_config.num_worker_replicas,
          total_num_replicas=run_config.num_worker_replicas,
          sparse_accumulator_type=self.train_config.sparse_accumulator_type)
    else:
      optimizer = sync_replicas_optimizer.SyncReplicasOptimizer(
          optimizer,
          replicas_to_aggregate=run_config.num_worker_replicas,
          total_num_replicas=run_config.num_worker_replicas)
    hooks.append(
        optimizer.make_session_run_hook(run_config.is_chief, num_tokens=0))
  # add barrier for no strategy case
  if run_config.num_worker_replicas > 1 and \
      self.train_config.train_distribute == DistributionStrategy.NoStrategy:
    hooks.append(
        estimator_utils.ExitBarrierHook(run_config.num_worker_replicas,
                                        run_config.is_chief, self.model_dir))
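  # optional early stopping: either a user supplied stop function or
  # stop when the best exporter metric no longer improves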
  if self.export_config.enable_early_stop:
    eval_dir = os.path.join(self._model_dir, 'eval_val')
    logging.info('will use early stop, eval_events_dir=%s' % eval_dir)
    if self.export_config.HasField('early_stop_func'):
      hooks.append(
          custom_early_stop_hook(
              self,
              eval_dir=eval_dir,
              custom_stop_func=self.export_config.early_stop_func,
              custom_stop_func_params=self.export_config.early_stop_params))
    elif self.export_config.metric_bigger:
      hooks.append(
          stop_if_no_increase_hook(
              self,
              self.export_config.best_exporter_metric,
              self.export_config.max_check_steps,
              eval_dir=eval_dir))
    else:
      hooks.append(
          stop_if_no_decrease_hook(
              self,
              self.export_config.best_exporter_metric,
              self.export_config.max_check_steps,
              eval_dir=eval_dir))
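  # optional external stop signals: an OSS stop signal and a hard dead_line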
  if self.train_config.enable_oss_stop_signal:
    hooks.append(oss_stop_hook(self))
  if self.train_config.HasField('dead_line'):
    hooks.append(deadline_stop_hook(self, self.train_config.dead_line))
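  # gradient summaries and optional clipping by global norm
  # (disabled when the configured norm is <= 0)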
  summaries = ['global_gradient_norm']
  if self.train_config.summary_model_vars:
    summaries.extend(['gradient_norm', 'gradients'])
  gradient_clipping_by_norm = self.train_config.gradient_clipping_by_norm
  if gradient_clipping_by_norm <= 0:
    gradient_clipping_by_norm = None
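  # optionally scale gradients of embedding variables by
  # embedding_learning_rate_multiplier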
  gradient_multipliers = None
  if self.train_config.optimizer_config[0].HasField(
      'embedding_learning_rate_multiplier'):
    gradient_multipliers = {
        var: self.train_config.optimizer_config[0]
        .embedding_learning_rate_multiplier
        for var in tf.trainable_variables()
        if 'embedding_weights:' in var.name or
        '/embedding_weights/part_' in var.name
    }
  # optimize loss
  # colocate_gradients_with_ops=True means gradients are computed on the
  # same device on which the op runs in the forward pass
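  # variables matching freeze_gradient patterns are excluded from training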
  all_train_vars = []
  if len(self.train_config.freeze_gradient) > 0:
    for one_var in tf.trainable_variables():
      is_freeze = False
      for x in self.train_config.freeze_gradient:
        if re.search(x, one_var.name) is not None:
          logging.info('will freeze gradients of %s' % one_var.name)
          is_freeze = True
          break
      if not is_freeze:
        all_train_vars.append(one_var)
  else:
    all_train_vars = tf.trainable_variables()
  if self.embedding_parallel:
    logging.info('embedding_parallel is enabled')
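  # build the train_op that computes and applies gradients for all_train_vars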
  train_op = optimizers.optimize_loss(
      loss=loss,
      global_step=tf.train.get_global_step(),
      learning_rate=None,
      clip_gradients=gradient_clipping_by_norm,
      optimizer=optimizer,
      gradient_multipliers=gradient_multipliers,
      variables=all_train_vars,
      summaries=summaries,
      colocate_gradients_with_ops=True,
      not_apply_grad_after_first_step=run_config.is_chief and
      self._pipeline_config.data_config.chief_redundant,
      name='',  # Preventing scope prefix on all variables.
      incr_save=(self.incr_save_config is not None),
      embedding_parallel=self.embedding_parallel)
  # online evaluation
  metric_update_op_dict = None
  if self.eval_config.eval_online:
    metric_update_op_dict = {}
    metric_dict = model.build_metric_graph(self.eval_config)
    for k, v in metric_dict.items():
      metric_update_op_dict['%s/batch' % k] = v[1]
      if isinstance(v[1], tf.Tensor):
        tf.summary.scalar('%s/batch' % k, v[1])
    train_op = tf.group([train_op] + list(metric_update_op_dict.values()))
    if estimator_utils.is_chief():
      hooks.append(
          estimator_utils.OnlineEvaluationHook(
              metric_dict=metric_dict, output_dir=self.model_dir))
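  # warm start from a fine-tune checkpoint when one is configured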
  if self.train_config.HasField('fine_tune_checkpoint'):
    fine_tune_ckpt = self.train_config.fine_tune_checkpoint
    logging.warning('will restore from %s' % fine_tune_ckpt)
    fine_tune_ckpt_var_map = self.train_config.fine_tune_ckpt_var_map
    force_restore = self.train_config.force_restore_shape_compatible
    restore_hook = model.restore(
        fine_tune_ckpt,
        include_global_step=False,
        ckpt_var_map_path=fine_tune_ckpt_var_map,
        force_restore_shape_compatible=force_restore)
    if restore_hook is not None:
      hooks.append(restore_hook)
  # logging
  logging_dict = OrderedDict()
  logging_dict['step'] = tf.train.get_global_step()
  logging_dict['lr'] = learning_rate[0]
  logging_dict.update(loss_dict)
  if metric_update_op_dict is not None:
    logging_dict.update(metric_update_op_dict)
  log_step_count_steps = self.train_config.log_step_count_steps
  logging_hook = basic_session_run_hooks.LoggingTensorHook(
      logging_dict,
      every_n_iter=log_step_count_steps,
      formatter=estimator_utils.tensor_log_format_func)
  hooks.append(logging_hook)
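  # build the Scaffold that controls saving and variable initialization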
  if self.train_config.train_distribute in [
      DistributionStrategy.CollectiveAllReduceStrategy,
      DistributionStrategy.MirroredStrategy,
      DistributionStrategy.MultiWorkerMirroredStrategy
  ]:
    # for multi-worker strategies we cannot replace the inner
    # CheckpointSaverHook, so just use the default scaffold.
    scaffold = tf.train.Scaffold()
  else:
    var_list = (
        tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) +
        tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS))
    # exclude data_offset_var
    var_list = [x for x in var_list if x != data_offset_var]
    # the early_stop flag is not saved in the checkpoint and
    # cannot be restored from the checkpoint
    early_stop_var = find_early_stop_var(var_list)
    var_list = [x for x in var_list if x != early_stop_var]
    initialize_var_list = [
        x for x in var_list if 'WorkQueue' not in str(type(x))
    ]
    # variables with incompatible shapes are not saved in the checkpoint
    # but must still be restorable from the checkpoint
    incompatiable_shape_restore = tf.get_collection('T_E_M_P_RESTROE')
    local_init_ops = [tf.train.Scaffold.default_local_init_op()]
    if data_offset_var is not None and estimator_utils.is_chief():
      local_init_ops.append(tf.initializers.variables([data_offset_var]))
    if early_stop_var is not None and estimator_utils.is_chief():
      local_init_ops.append(tf.initializers.variables([early_stop_var]))
    if len(incompatiable_shape_restore) > 0:
      local_init_ops.append(
          tf.initializers.variables(incompatiable_shape_restore))
    scaffold = tf.train.Scaffold(
        saver=self.saver_cls(
            var_list=var_list,
            sharded=True,
            max_to_keep=self.train_config.keep_checkpoint_max,
            save_relative_paths=True),
        local_init_op=tf.group(local_init_ops),
        ready_for_local_init_op=tf.report_uninitialized_variables(
            var_list=initialize_var_list))
  # saver hook
  saver_hook = estimator_utils.CheckpointSaverHook(
      checkpoint_dir=self.model_dir,
      save_secs=self._config.save_checkpoints_secs,
      save_steps=self._config.save_checkpoints_steps,
      scaffold=scaffold,
      write_graph=self.train_config.write_graph,
      data_offset_var=data_offset_var,
      increment_save_config=self.incr_save_config)
  if estimator_utils.is_chief() or self.embedding_parallel:
    hooks.append(saver_hook)
  if estimator_utils.is_chief():
    hooks.append(
        basic_session_run_hooks.StepCounterHook(
            every_n_steps=log_step_count_steps, output_dir=self.model_dir))
  # profiling hook
  if self.train_config.is_profiling and estimator_utils.is_chief():
    profile_hook = tf.train.ProfilerHook(
        save_steps=log_step_count_steps, output_dir=self.model_dir)
    hooks.append(profile_hook)
  return tf.estimator.EstimatorSpec(
      mode=tf.estimator.ModeKeys.TRAIN,
      loss=loss,
      predictions=predict_dict,
      train_op=train_op,
      scaffold=scaffold,
      training_hooks=hooks)