in adanet/autoensemble/common.py [0:0]
def build_subnetwork(self,
                     features,
                     labels,
                     logits_dimension,
                     training,
                     iteration_step,
                     summary,
                     previous_ensemble,
                     config=None):
  """Builds a `Subnetwork` by calling the wrapped subestimator's model_fn.

  The model_fn runs in TRAIN mode only when `training` is True and the
  subestimator is not `prediction_only`; otherwise it runs in PREDICT mode.
  There is no EVAL mode because AdaNet takes care of evaluation itself.

  When the model_fn runs in TRAIN mode and the subestimator defines its own
  `train_input_fn`, the model_fn is called twice through one shared template
  (so variables are created once and reused):
    1. On the `train_input_fn` inputs, to obtain the subestimator's own
       ("bagging") train op.
    2. On the ensemble `features`/`labels`, to obtain logits and last layer.
  The bagging train op is executed by a `_SecondaryTrainOpRunnerHook`. The
  returned `TrainOpSpec` carries a `tf.no_op()` as its train op.

  Args:
    features: Input features for the ensemble.
    labels: Labels for the ensemble.
    logits_dimension: Unused by this implementation.
    training: Whether the subnetwork is being built for training.
    iteration_step: Unused by this implementation.
    summary: Passed through unchanged to `self._call_model_fn`.
    previous_ensemble: Unused by this implementation.
    config: Optional config, passed to `self._subestimator` to construct the
      subestimator.

  Returns:
    A `subnetwork_lib.Subnetwork` with:
      - the model_fn's logits and last layer;
      - `shared={"train_op": <TrainOpSpec>}`;
      - a constant complexity of 0;
      - whatever local init ops the model_fn call(s) produced.
  """
  subestimator = self._subestimator(config)
  # We don't need an EVAL mode since AdaNet takes care of evaluation for us.
  if training and not subestimator.prediction_only:
    mode = tf.estimator.ModeKeys.TRAIN
  else:
    mode = tf.estimator.ModeKeys.PREDICT

  # Call in template to ensure that variables are created once and reused.
  call_model_fn_template = tf.compat.v1.make_template("model_fn",
                                                      self._call_model_fn)
  local_init_ops = []

  # Only build the secondary (bagging) train op when the model_fn really runs
  # in TRAIN mode. A PREDICT-mode spec is generally not expected to carry a
  # train op, and handing a `None` train op to `_SecondaryTrainOpRunnerHook`
  # would fail once the hook tries to run it.
  if mode == tf.estimator.ModeKeys.TRAIN and subestimator.train_input_fn:
    # TODO: Consider tensorflow_estimator/python/estimator/util.py.
    inputs = subestimator.train_input_fn()
    # `train_input_fn` may return a Dataset yielding (features, labels) or a
    # (features, labels) pair directly.
    if isinstance(inputs, (tf_compat.DatasetV1, tf_compat.DatasetV2)):
      subestimator_features, subestimator_labels = (
          tf_compat.make_one_shot_iterator(inputs).get_next())
    else:
      subestimator_features, subestimator_labels = inputs

    # Construct subnetwork graph first because of dependencies on scope.
    _, _, bagging_train_op_spec, sub_local_init_op = call_model_fn_template(
        subestimator, subestimator_features, subestimator_labels, mode,
        summary)
    # Graph for ensemble learning gets model_fn_1 for scope.
    logits, last_layer, _, ensemble_local_init_op = call_model_fn_template(
        subestimator, features, labels, mode, summary)

    # Compare against None rather than relying on truthiness, which is not
    # defined for graph tensors.
    for init_op in (sub_local_init_op, ensemble_local_init_op):
      if init_op is not None:
        local_init_ops.append(init_op)

    # Run train op in a hook so that exceptions can be intercepted by the
    # AdaNet framework instead of the Estimator's monitored training session.
    # Normalize to a tuple so the concatenation also works if `hooks` is a
    # list or empty/None.
    hooks = tuple(bagging_train_op_spec.hooks or ()) + (
        _SecondaryTrainOpRunnerHook(bagging_train_op_spec.train_op),)
    train_op_spec = subnetwork_lib.TrainOpSpec(
        train_op=tf.no_op(),
        chief_hooks=bagging_train_op_spec.chief_hooks,
        hooks=hooks)
  else:
    logits, last_layer, train_op_spec, local_init_op = call_model_fn_template(
        subestimator, features, labels, mode, summary)
    if local_init_op is not None:
      local_init_ops.append(local_init_op)

  # TODO: Replace with variance complexity measure.
  complexity = tf.constant(0.)
  return subnetwork_lib.Subnetwork(
      logits=logits,
      last_layer=last_layer,
      shared={"train_op": train_op_spec},
      complexity=complexity,
      local_init_ops=local_init_ops)