in adanet/core/ensemble_builder.py [0:0]
def build_subnetwork_spec(self,
                          name,
                          subnetwork_builder,
                          summary,
                          features,
                          mode,
                          labels=None,
                          previous_ensemble=None,
                          config=None):
    """Creates the `_SubnetworkSpec` for a single candidate subnetwork.

    The subnetwork is built inside the `subnetwork_<name>` variable scope,
    together with a non-trainable int64 `step` variable created here. The
    builder only ever receives a tensor view of that step.

    Args:
      name: String name of the subnetwork.
      subnetwork_builder: The `adanet.subnetwork.Builder` whose
        `build_subnetwork` and, in TRAIN mode, `build_subnetwork_train_op`
        methods are called.
      summary: A `_ScopedSummary` instance used for recording summaries.
      features: Input `dict` of `Tensor` objects.
      mode: Estimator's `ModeKeys`.
      labels: Labels `Tensor`, or a dictionary of string label name to
        `Tensor` (for multi-head). Can be `None`.
      previous_ensemble: The previous `Ensemble` from iteration t-1. It is
        forwarded to the builder when building the subnetwork and its
        train op.
      config: The `tf.estimator.RunConfig` to use this iteration. When given,
        the builder sees a copy whose `model_dir` is
        `<config.model_dir>/assets/<name>`; otherwise a new `RunConfig` with
        GPU `allow_growth` enabled is created.

    Returns:
      A new `_SubnetworkSpec` holding the built subnetwork, its predictions,
      loss, step variable, train op (`None` outside TRAIN mode), eval metrics
      and every variable created since this method started.
    """
    vars_before = _get_current_vars()
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    with tf_compat.v1.variable_scope("subnetwork_{}".format(name)):
        step = tf_compat.v1.get_variable(
            "step",
            shape=[],
            initializer=tf_compat.v1.zeros_initializer(),
            trainable=False,
            dtype=tf.int64)
        # Hand out a tensor rather than the variable so that users cannot
        # mutate the step.
        step_tensor = tf.convert_to_tensor(value=step)
        with summary.current_scope():
            summary.scalar("iteration_step/adanet/iteration_step",
                           step_tensor)

        if not config:
            subnetwork_config = tf.estimator.RunConfig(
                session_config=tf.compat.v1.ConfigProto(
                    gpu_options=tf.compat.v1.GPUOptions(allow_growth=True)))
        else:
            asset_dir = os.path.join(config.model_dir, "assets", name)
            subnetwork_config = config.replace(model_dir=asset_dir)

        build_fn = subnetwork_builder.build_subnetwork
        build_kwargs = dict(
            features=features,
            logits_dimension=self._head.logits_dimension,
            training=is_training,
            iteration_step=step_tensor,
            summary=summary,
            previous_ensemble=previous_ensemble)
        # `labels` and `config` are only passed to builders whose
        # `build_subnetwork` signature declares them, for backwards
        # compatibility. The low-level `getargs` is used because it works on
        # both Python 2 and Python 3.
        accepted_args = inspect.getargs(build_fn.__code__).args
        if "labels" in accepted_args:
            build_kwargs["labels"] = labels
        if "config" in accepted_args:
            build_kwargs["config"] = subnetwork_config

        subnetwork_scope = tf_compat.v1.get_variable_scope()
        with summary.current_scope(), _monkey_patch_context(
                iteration_step_scope=subnetwork_scope,
                scoped_summary=summary,
                trainable_vars=[]):
            subnetwork = build_fn(**build_kwargs)

        # Trainable variables that appeared since the initial snapshot.
        trainable_vars = _get_current_vars(diffbase=vars_before)["trainable"]

        estimator_spec = _create_estimator_spec(
            self._head, features, labels, mode, subnetwork.logits,
            self._use_tpu)

        subnetwork_metrics = _SubnetworkMetrics(self._use_tpu)
        if mode == tf.estimator.ModeKeys.EVAL:
            subnetwork_metrics.create_eval_metrics(
                features=features,
                labels=labels,
                estimator_spec=estimator_spec,
                metric_fn=self._metric_fn)

        # The loss summary and the train op are only created in TRAIN mode.
        train_op = None
        if is_training:
            with summary.current_scope():
                summary.scalar("loss", estimator_spec.loss)
            if subnetwork_builder:
                with summary.current_scope(), _monkey_patch_context(
                        iteration_step_scope=subnetwork_scope,
                        scoped_summary=summary,
                        trainable_vars=trainable_vars):
                    raw_train_op = subnetwork_builder.build_subnetwork_train_op(
                        subnetwork=subnetwork,
                        loss=estimator_spec.loss,
                        var_list=trainable_vars,
                        labels=labels,
                        iteration_step=step_tensor,
                        summary=summary,
                        previous_ensemble=previous_ensemble)
                    train_op = _to_train_op_spec(raw_train_op)

    created_vars = _get_current_vars(diffbase=vars_before)
    # Walk the variable groups in sorted-key order so the flattened list does
    # not depend on dictionary ordering.
    ordered_groups = [group for _, group in sorted(created_vars.items())]
    # Concatenate every group into one list and append the step variable.
    all_variables = sum(ordered_groups, []) + [step]

    return _SubnetworkSpec(
        name=name,
        subnetwork=subnetwork,
        builder=subnetwork_builder,
        predictions=estimator_spec.predictions,
        variables=all_variables,
        loss=estimator_spec.loss,
        step=step,
        train_op=train_op,
        eval_metrics=subnetwork_metrics,
        asset_dir=subnetwork_config.model_dir)