def build_subnetwork()

in adanet/autoensemble/common.py


  def build_subnetwork(self,
                       features,
                       labels,
                       logits_dimension,
                       training,
                       iteration_step,
                       summary,
                       previous_ensemble,
                       config=None):
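    """Builds an `adanet.subnetwork.Subnetwork` from the wrapped subestimator.

    Calls the subestimator's `model_fn` through a shared template. When
    training with a custom `train_input_fn` (bagging), two graphs are built:
    one on the subestimator's own inputs to produce its train op, and one on
    the ensemble's inputs to produce the logits and last layer.
    """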
    # We don't need an EVAL mode since AdaNet takes care of evaluation for us.
    subestimator = self._subestimator(config)
    mode = tf.estimator.ModeKeys.PREDICT
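    # Subestimators marked prediction_only never enter TRAIN mode: they only
    # ever produce predictions, even while the rest of the ensemble trains.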
    if training and not subestimator.prediction_only:
      mode = tf.estimator.ModeKeys.TRAIN

    # Call through a template so that variables are created once and then
    # reused on every subsequent call.
    call_model_fn_template = tf.compat.v1.make_template("model_fn",
                                                        self._call_model_fn)
    subestimator_features, subestimator_labels = features, labels
    local_init_ops = []
    if training and subestimator.train_input_fn:
      # TODO: Consider tensorflow_estimator/python/estimator/util.py.
      inputs = subestimator.train_input_fn()
      if isinstance(inputs, (tf_compat.DatasetV1, tf_compat.DatasetV2)):
        subestimator_features, subestimator_labels = (
            tf_compat.make_one_shot_iterator(inputs).get_next())
      else:
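        # The input_fn returned a (features, labels) tuple directly.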
        subestimator_features, subestimator_labels = inputs

      # Build the bagged training graph first: the first call through the
      # template creates the variables that later calls reuse.
      _, _, bagging_train_op_spec, sub_local_init_op = call_model_fn_template(
          subestimator, subestimator_features, subestimator_labels, mode,
          summary)
      # The ensemble-learning graph is built second, so its ops are named
      # under the "model_fn_1" scope while sharing the same variables.
      logits, last_layer, _, ensemble_local_init_op = call_model_fn_template(
          subestimator, features, labels, mode, summary)

      if sub_local_init_op:
        local_init_ops.append(sub_local_init_op)
      if ensemble_local_init_op:
        local_init_ops.append(ensemble_local_init_op)

      # Run train op in a hook so that exceptions can be intercepted by the
      # AdaNet framework instead of the Estimator's monitored training session.
      hooks = bagging_train_op_spec.hooks + (_SecondaryTrainOpRunnerHook(
          bagging_train_op_spec.train_op),)
      train_op_spec = subnetwork_lib.TrainOpSpec(
          train_op=tf.no_op(),
          chief_hooks=bagging_train_op_spec.chief_hooks,
          hooks=hooks)
    else:
      logits, last_layer, train_op_spec, local_init_op = call_model_fn_template(
          subestimator, features, labels, mode, summary)
      if local_init_op:
        local_init_ops.append(local_init_op)

    # TODO: Replace with variance complexity measure.
    complexity = tf.constant(0.)
    return subnetwork_lib.Subnetwork(
        logits=logits,
        last_layer=last_layer,
        shared={"train_op": train_op_spec},
        complexity=complexity,
        local_init_ops=local_init_ops)
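
For context on the template trick above, here is a minimal sketch (assuming
TensorFlow 2.x with the tf.compat.v1 graph-mode APIs) of how
tf.compat.v1.make_template gives two calls shared variables but distinct op
name scopes, which is why the ensemble graph lands under "model_fn_1":

  import tensorflow as tf

  def _linear(x):
    # Created on the first template call, reused on every later call.
    w = tf.compat.v1.get_variable(
        "w", shape=[], initializer=tf.compat.v1.ones_initializer())
    return x * w

  graph = tf.Graph()
  with graph.as_default():
    template = tf.compat.v1.make_template("model_fn", _linear)
    y0 = template(tf.constant(1.0))  # Ops under name scope "model_fn".
    y1 = template(tf.constant(2.0))  # Ops under "model_fn_1"; "w" is shared.
    # Two calls, but only one trainable variable.
    assert len(tf.compat.v1.trainable_variables()) == 1

The _SecondaryTrainOpRunnerHook referenced above is not shown in this
excerpt. A plausible sketch, not the actual implementation, is a
SessionRunHook that runs the wrapped train op before each step, so failures
surface to AdaNet rather than to the Estimator's monitored training session:

  import tensorflow as tf

  class _SecondaryTrainOpRunnerHook(tf.estimator.SessionRunHook):
    """Sketch: runs a secondary train op once per training step."""

    def __init__(self, train_op):
      self._train_op = train_op

    def before_run(self, run_context):
      # Running the op here, instead of returning it as the session's main
      # train_op, lets the caller driving the session intercept any
      # exception it raises.
      run_context.session.run(self._train_op)

Finally, a minimal usage sketch of how this builder is reached in practice,
assuming the documented adanet.AutoEnsembleEstimator API; the feature column
"x", hidden_units=[64], and max_iteration_steps=1000 are illustrative values
only:

  import adanet
  import tensorflow as tf

  feature_columns = [tf.feature_column.numeric_column("x")]
  head = tf.estimator.BinaryClassHead()

  estimator = adanet.AutoEnsembleEstimator(
      head=head,
      candidate_pool=lambda config: {
          "linear": tf.estimator.LinearEstimator(
              head=head, feature_columns=feature_columns, config=config),
          "dnn": tf.estimator.DNNEstimator(
              head=head, feature_columns=feature_columns,
              hidden_units=[64], config=config),
      },
      max_iteration_steps=1000)

Each candidate's model_fn is then wrapped by build_subnetwork at every
AdaNet iteration.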