def _train_model_fn()

in easy_rec/python/model/easy_rec_estimator.py
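
Builds the EstimatorSpec for training: it constructs the model's predict and loss graphs, assembles the total loss and train_op, and attaches the training hooks (synchronization, early stopping, logging, checkpointing, profiling).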


  def _train_model_fn(self, features, labels, run_config):
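    """Build the EstimatorSpec for training.

    Constructs the model's predict and loss graphs, assembles the total loss,
    builds the optimizer(s) and train_op, and attaches the training hooks
    (synchronization, early stopping, logging, checkpointing and profiling).
    """
    # set the Keras backend learning phase to training so layers such as
    # dropout and batch normalization behave accordingly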
    tf.keras.backend.set_learning_phase(1)
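    # instantiate the model and build its predict and loss graphs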
    model = self._model_cls(
        self.model_config,
        self.feature_configs,
        features,
        labels,
        is_training=True)
    predict_dict = model.build_predict_graph()
    loss_dict = model.build_loss_graph()

    # fold any regularization terms registered in the graph into the loss dict
    regularization_losses = tf.get_collection(
        tf.GraphKeys.REGULARIZATION_LOSSES)
    if regularization_losses:
      # some collected entries are wrappers exposing .get(); unwrap them first
      regularization_losses = [
          reg_loss.get() if hasattr(reg_loss, 'get') else reg_loss
          for reg_loss in regularization_losses
      ]
      regularization_losses = tf.add_n(
          regularization_losses, name='regularization_loss')
      loss_dict['regularization_loss'] = regularization_losses

    # losses contributed by variational dropout layers, if any
    variational_dropout_loss = tf.get_collection('variational_dropout_loss')
    if variational_dropout_loss:
      variational_dropout_loss = tf.add_n(
          variational_dropout_loss, name='variational_dropout_loss')
      loss_dict['variational_dropout_loss'] = variational_dropout_loss

    # sum all loss items into the total training loss and summarize each one
    loss = tf.add_n(list(loss_dict.values()))
    loss_dict['total_loss'] = loss
    for key in loss_dict:
      tf.summary.scalar(key, loss_dict[key], family='loss')

    # keep each worker's input offset in a shared string variable so that data
    # progress can be saved together with checkpoints
    if Input.DATA_OFFSET in features:
      task_index, task_num = estimator_utils.get_task_index_and_num()
      data_offset_var = tf.get_variable(
          name=Input.DATA_OFFSET,
          dtype=tf.string,
          shape=[task_num],
          collections=[tf.GraphKeys.GLOBAL_VARIABLES, Input.DATA_OFFSET],
          trainable=False)
      update_offset = tf.assign(data_offset_var[task_index],
                                features[Input.DATA_OFFSET])
      ops.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_offset)
    else:
      data_offset_var = None

    # update op, usually used for batch-norm
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if update_ops:
      # register variables for incremental update, such as batch-norm
      # moving_mean and moving_variance
      global_vars = {x.name: x for x in tf.global_variables()}
      for x in update_ops:
        if isinstance(x, ops.Operation) and x.inputs[0].name in global_vars:
          ops.add_to_collection(constant.DENSE_UPDATE_VARIABLES,
                                global_vars[x.inputs[0].name])
      update_op = tf.group(*update_ops, name='update_barrier')
      with tf.control_dependencies([update_op]):
        loss = tf.identity(loss, name='total_loss')

    # build optimizer
    if len(self.train_config.optimizer_config) == 1:
      optimizer_config = self.train_config.optimizer_config[0]
      optimizer, learning_rate = optimizer_builder.build(optimizer_config)
      tf.summary.scalar('learning_rate', learning_rate[0])
    else:
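      # multiple optimizers: build one per config and assign each variable
      # group returned by the model to its own optimizer via MultiOptimizer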
      optimizer_config = self.train_config.optimizer_config
      all_opts = []
      for opti_id, tmp_config in enumerate(optimizer_config):
        with tf.name_scope('optimizer_%d' % opti_id):
          opt, learning_rate = optimizer_builder.build(tmp_config)
          tf.summary.scalar('learning_rate', learning_rate[0])
        all_opts.append(opt)
      grouped_vars = model.get_grouped_vars(len(all_opts))
      assert len(grouped_vars) == len(optimizer_config), \
          'the number of var groups(%d) != the number of optimizers(%d)' \
          % (len(grouped_vars), len(optimizer_config))
      optimizer = MultiOptimizer(all_opts, grouped_vars)

    # wrap the optimizer for the SOK (SparseOperationKit) embedding strategy
    if self.train_config.train_distribute == DistributionStrategy.SokStrategy:
      optimizer = sok_optimizer.OptimizerWrapper(optimizer)

    hooks = []
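    # with Horovod, broadcast initial variable values from rank 0 to all workers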
    if estimator_utils.has_hvd():
      assert not self.train_config.sync_replicas, \
          'sync_replicas should not be set when using horovod'
      bcast_hook = hvd_utils.BroadcastGlobalVariablesHook(0)
      hooks.append(bcast_hook)

    # for distributed and synced training
    if self.train_config.sync_replicas and run_config.num_worker_replicas > 1:
      logging.info('sync_replicas: num_worker_replicas = %d' %
                   run_config.num_worker_replicas)
      if pai_util.is_on_pai():
        optimizer = tf.train.SyncReplicasOptimizer(
            optimizer,
            replicas_to_aggregate=run_config.num_worker_replicas,
            total_num_replicas=run_config.num_worker_replicas,
            sparse_accumulator_type=self.train_config.sparse_accumulator_type)
      else:
        optimizer = sync_replicas_optimizer.SyncReplicasOptimizer(
            optimizer,
            replicas_to_aggregate=run_config.num_worker_replicas,
            total_num_replicas=run_config.num_worker_replicas)
      hooks.append(
          optimizer.make_session_run_hook(run_config.is_chief, num_tokens=0))

    # add an exit barrier so that all workers finish together when no
    # distribution strategy is used
    if run_config.num_worker_replicas > 1 and \
       self.train_config.train_distribute == DistributionStrategy.NoStrategy:
      hooks.append(
          estimator_utils.ExitBarrierHook(run_config.num_worker_replicas,
                                          run_config.is_chief, self.model_dir))

    # early stopping: use a custom stop function if provided, otherwise stop
    # when the best-exporter metric stops increasing (or decreasing)
    if self.export_config.enable_early_stop:
      eval_dir = os.path.join(self._model_dir, 'eval_val')
      logging.info('will use early stop, eval_events_dir=%s' % eval_dir)
      if self.export_config.HasField('early_stop_func'):
        hooks.append(
            custom_early_stop_hook(
                self,
                eval_dir=eval_dir,
                custom_stop_func=self.export_config.early_stop_func,
                custom_stop_func_params=self.export_config.early_stop_params))
      elif self.export_config.metric_bigger:
        hooks.append(
            stop_if_no_increase_hook(
                self,
                self.export_config.best_exporter_metric,
                self.export_config.max_check_steps,
                eval_dir=eval_dir))
      else:
        hooks.append(
            stop_if_no_decrease_hook(
                self,
                self.export_config.best_exporter_metric,
                self.export_config.max_check_steps,
                eval_dir=eval_dir))

    # stop training when an OSS stop signal is detected
    if self.train_config.enable_oss_stop_signal:
      hooks.append(oss_stop_hook(self))

    # stop training once the configured dead_line is reached
    if self.train_config.HasField('dead_line'):
      hooks.append(deadline_stop_hook(self, self.train_config.dead_line))

    summaries = ['global_gradient_norm']
    if self.train_config.summary_model_vars:
      summaries.extend(['gradient_norm', 'gradients'])

    # a non-positive value disables gradient clipping
    gradient_clipping_by_norm = self.train_config.gradient_clipping_by_norm
    if gradient_clipping_by_norm <= 0:
      gradient_clipping_by_norm = None

    # optionally scale embedding gradients by embedding_learning_rate_multiplier
    gradient_multipliers = None
    if self.train_config.optimizer_config[0].HasField(
        'embedding_learning_rate_multiplier'):
      gradient_multipliers = {
          var: self.train_config.optimizer_config[0]
          .embedding_learning_rate_multiplier
          for var in tf.trainable_variables()
          if 'embedding_weights:' in var.name or
          '/embedding_weights/part_' in var.name
      }

    # optimize loss
    # colocate_gradients_with_ops=True computes each gradient on the same
    # device that ran the corresponding op in the forward pass
    all_train_vars = []
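    # drop variables matching any freeze_gradient pattern from the train list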
    if len(self.train_config.freeze_gradient) > 0:
      for one_var in tf.trainable_variables():
        is_freeze = False
        for x in self.train_config.freeze_gradient:
          if re.search(x, one_var.name) is not None:
            logging.info('will freeze gradients of %s' % one_var.name)
            is_freeze = True
            break
        if not is_freeze:
          all_train_vars.append(one_var)
    else:
      all_train_vars = tf.trainable_variables()

    if self.embedding_parallel:
      logging.info('embedding_parallel is enabled')

    train_op = optimizers.optimize_loss(
        loss=loss,
        global_step=tf.train.get_global_step(),
        learning_rate=None,
        clip_gradients=gradient_clipping_by_norm,
        optimizer=optimizer,
        gradient_multipliers=gradient_multipliers,
        variables=all_train_vars,
        summaries=summaries,
        colocate_gradients_with_ops=True,
        not_apply_grad_after_first_step=run_config.is_chief and
        self._pipeline_config.data_config.chief_redundant,
        name='',  # Preventing scope prefix on all variables.
        incr_save=(self.incr_save_config is not None),
        embedding_parallel=self.embedding_parallel)

    # online evaluation
    metric_update_op_dict = None
    if self.eval_config.eval_online:
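      # run metric update ops together with the train_op and summarize
      # per-batch values; the chief writes aggregated results via a hook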
      metric_update_op_dict = {}
      metric_dict = model.build_metric_graph(self.eval_config)
      for k, v in metric_dict.items():
        metric_update_op_dict['%s/batch' % k] = v[1]
        if isinstance(v[1], tf.Tensor):
          tf.summary.scalar('%s/batch' % k, v[1])
      train_op = tf.group([train_op] + list(metric_update_op_dict.values()))
      if estimator_utils.is_chief():
        hooks.append(
            estimator_utils.OnlineEvaluationHook(
                metric_dict=metric_dict, output_dir=self.model_dir))

    # warm-start from a fine-tune checkpoint, optionally remapping variable names
    if self.train_config.HasField('fine_tune_checkpoint'):
      fine_tune_ckpt = self.train_config.fine_tune_checkpoint
      logging.warning('will restore from %s' % fine_tune_ckpt)
      fine_tune_ckpt_var_map = self.train_config.fine_tune_ckpt_var_map
      force_restore = self.train_config.force_restore_shape_compatible
      restore_hook = model.restore(
          fine_tune_ckpt,
          include_global_step=False,
          ckpt_var_map_path=fine_tune_ckpt_var_map,
          force_restore_shape_compatible=force_restore)
      if restore_hook is not None:
        hooks.append(restore_hook)

    # periodically log step, learning rate, losses and online metrics
    logging_dict = OrderedDict()
    logging_dict['step'] = tf.train.get_global_step()
    logging_dict['lr'] = learning_rate[0]
    logging_dict.update(loss_dict)
    if metric_update_op_dict is not None:
      logging_dict.update(metric_update_op_dict)

    log_step_count_steps = self.train_config.log_step_count_steps
    logging_hook = basic_session_run_hooks.LoggingTensorHook(
        logging_dict,
        every_n_iter=log_step_count_steps,
        formatter=estimator_utils.tensor_log_format_func)
    hooks.append(logging_hook)

    if self.train_config.train_distribute in [
        DistributionStrategy.CollectiveAllReduceStrategy,
        DistributionStrategy.MirroredStrategy,
        DistributionStrategy.MultiWorkerMirroredStrategy
    ]:
      # for multi-worker strategies we cannot replace the inner
      # CheckpointSaverHook, so just use the default scaffold
      scaffold = tf.train.Scaffold()
    else:
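      # build a custom Scaffold/saver so that data-offset, early-stop and
      # temporary shape-restore variables can be excluded from the checkpoint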
      var_list = (
          tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) +
          tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS))

      # exclude data_offset_var
      var_list = [x for x in var_list if x != data_offset_var]
      # the early_stop flag is not saved in the checkpoint
      # and cannot be restored from it
      early_stop_var = find_early_stop_var(var_list)
      var_list = [x for x in var_list if x != early_stop_var]

      initialize_var_list = [
          x for x in var_list if 'WorkQueue' not in str(type(x))
      ]

      # incompatible-shape restore variables are not saved in the checkpoint
      # but must still be restorable from it
      incompatiable_shape_restore = tf.get_collection('T_E_M_P_RESTROE')

      local_init_ops = [tf.train.Scaffold.default_local_init_op()]
      if data_offset_var is not None and estimator_utils.is_chief():
        local_init_ops.append(tf.initializers.variables([data_offset_var]))
      if early_stop_var is not None and estimator_utils.is_chief():
        local_init_ops.append(tf.initializers.variables([early_stop_var]))
      if len(incompatiable_shape_restore) > 0:
        local_init_ops.append(
            tf.initializers.variables(incompatiable_shape_restore))

      scaffold = tf.train.Scaffold(
          saver=self.saver_cls(
              var_list=var_list,
              sharded=True,
              max_to_keep=self.train_config.keep_checkpoint_max,
              save_relative_paths=True),
          local_init_op=tf.group(local_init_ops),
          ready_for_local_init_op=tf.report_uninitialized_variables(
              var_list=initialize_var_list))
      # saver hook
      saver_hook = estimator_utils.CheckpointSaverHook(
          checkpoint_dir=self.model_dir,
          save_secs=self._config.save_checkpoints_secs,
          save_steps=self._config.save_checkpoints_steps,
          scaffold=scaffold,
          write_graph=self.train_config.write_graph,
          data_offset_var=data_offset_var,
          increment_save_config=self.incr_save_config)
      if estimator_utils.is_chief() or self.embedding_parallel:
        hooks.append(saver_hook)
      if estimator_utils.is_chief():
        hooks.append(
            basic_session_run_hooks.StepCounterHook(
                every_n_steps=log_step_count_steps, output_dir=self.model_dir))

    # profiling hook
    if self.train_config.is_profiling and estimator_utils.is_chief():
      profile_hook = tf.train.ProfilerHook(
          save_steps=log_step_count_steps, output_dir=self.model_dir)
      hooks.append(profile_hook)

    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=loss,
        predictions=predict_dict,
        train_op=train_op,
        scaffold=scaffold,
        training_hooks=hooks)