in adanet/core/ensemble_builder.py [0:0]
def build_ensemble_spec(self,
name,
candidate,
ensembler,
subnetwork_specs,
summary,
features,
mode,
iteration_number,
labels,
my_ensemble_index,
previous_ensemble_spec,
previous_iteration_checkpoint):
"""Builds an `_EnsembleSpec` with the given `adanet.ensemble.Candidate`.
Args:
name: The string name of the ensemble. Typically the name of the builder
that returned the given `Subnetwork`.
candidate: The `adanet.ensemble.Candidate` for this spec.
ensembler: The :class:`adanet.ensemble.Ensembler` to use to ensemble a
group of subnetworks.
subnetwork_specs: Iterable of `_SubnetworkSpecs` for this iteration.
summary: A `_ScopedSummary` instance for recording ensemble summaries.
features: Input `dict` of `Tensor` objects.
mode: Estimator `ModeKeys` indicating training, evaluation, or inference.
iteration_number: Integer current iteration number.
labels: Labels `Tensor` or a dictionary of string label name to `Tensor`
(for multi-head).
my_ensemble_index: An integer holding the index of the ensemble in the
candidates list of AdaNet.
previous_ensemble_spec: Link the rest of the `_EnsembleSpec` from
iteration t-1. Used for creating the subnetwork train_op.
previous_iteration_checkpoint: `tf.train.Checkpoint` for iteration t-1.
Returns:
An `_EnsembleSpec` instance.
"""
with tf_compat.v1.variable_scope("ensemble_{}".format(name)):
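# Each ensemble keeps its own step counter, separate from the global step,
# so per-candidate training progress can be tracked and summarized.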
step = tf_compat.v1.get_variable(
"step",
shape=[],
initializer=tf_compat.v1.zeros_initializer(),
trainable=False,
dtype=tf.int64)
# Convert to tensor so that users cannot mutate it.
step_tensor = tf.convert_to_tensor(value=step)
with summary.current_scope():
summary.scalar("iteration_step/adanet/iteration_step", step_tensor)
replay_indices = []
if previous_ensemble_spec:
replay_indices = copy.copy(
previous_ensemble_spec.architecture.replay_indices)
if my_ensemble_index is not None:
replay_indices.append(my_ensemble_index)
architecture = _Architecture(
candidate.name, ensembler.name, replay_indices=replay_indices)
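# Carry over subnetworks from the previous ensemble, unless a deprecated
# `prune_previous_ensemble` hook on the builder asks to drop some of them.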
previous_subnetworks = []
previous_subnetwork_specs = []
subnetwork_builders = []
previous_ensemble = None
if previous_ensemble_spec:
previous_ensemble = previous_ensemble_spec.ensemble
previous_architecture = previous_ensemble_spec.architecture
keep_indices = range(len(previous_ensemble.subnetworks))
if len(candidate.subnetwork_builders) == 1 and previous_ensemble:
# Prune previous ensemble according to the subnetwork.Builder for
# backwards compatibility.
subnetwork_builder = candidate.subnetwork_builders[0]
prune_previous_ensemble = getattr(subnetwork_builder,
"prune_previous_ensemble", None)
if callable(prune_previous_ensemble):
logging.warning(
"Using an `adanet.subnetwork.Builder#prune_previous_ensemble` "
"is deprecated. Please use a custom `adanet.ensemble.Strategy` "
"instead.")
keep_indices = prune_previous_ensemble(previous_ensemble)
for i, builder in enumerate(previous_ensemble_spec.subnetwork_builders):
if i not in keep_indices:
continue
if builder not in candidate.previous_ensemble_subnetwork_builders:
continue
previous_subnetworks.append(previous_ensemble.subnetworks[i])
previous_subnetwork_specs.append(
previous_ensemble_spec.subnetwork_specs[i])
subnetwork_builders.append(builder)
architecture.add_subnetwork(*previous_architecture.subnetworks[i])
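# Record this iteration's new subnetworks in the architecture.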
for builder in candidate.subnetwork_builders:
architecture.add_subnetwork(iteration_number, builder.name)
subnetwork_builders.append(builder)
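# Index this iteration's specs by builder name to select the ones that
# belong to this candidate.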
subnetwork_spec_map = {s.builder.name: s for s in subnetwork_specs}
relevant_subnetwork_specs = [
subnetwork_spec_map[s.name] for s in candidate.subnetwork_builders
]
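# Snapshot existing variables so that the ones created by the ensembler
# can be identified later by diffing.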
ensemble_scope = tf_compat.v1.get_variable_scope()
old_vars = _get_current_vars()
with summary.current_scope(), _monkey_patch_context(
iteration_step_scope=ensemble_scope,
scoped_summary=summary,
trainable_vars=[]):
ensemble = ensembler.build_ensemble(
subnetworks=[s.subnetwork for s in relevant_subnetwork_specs],
previous_ensemble_subnetworks=previous_subnetworks,
features=features,
labels=labels,
logits_dimension=self._head.logits_dimension,
training=mode == tf.estimator.ModeKeys.TRAIN,
iteration_step=step_tensor,
summary=summary,
previous_ensemble=previous_ensemble,
previous_iteration_checkpoint=previous_iteration_checkpoint)
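# Let the head build the loss, predictions, and export outputs from the
# ensemble's logits.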
estimator_spec = _create_estimator_spec(self._head, features, labels,
mode, ensemble.logits,
self._use_tpu)
ensemble_loss = estimator_spec.loss
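# The AdaNet loss augments the ensemble loss with a complexity penalty and
# is the objective minimized by the ensembler's train op.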
adanet_loss = None
if mode != tf.estimator.ModeKeys.PREDICT:
adanet_loss = estimator_spec.loss
# Add the ensembler-specific complexity regularization to the loss.
if isinstance(ensemble, ensemble_lib.ComplexityRegularized):
adanet_loss += ensemble.complexity_regularization
predictions = estimator_spec.predictions
export_outputs = estimator_spec.export_outputs
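# Optionally export each subnetwork's logits as additional serving
# signatures, one signature per head for multi-head subnetworks.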
if (self._export_subnetwork_logits and
export_outputs and subnetwork_spec_map):
first_subnetwork_logits = list(
subnetwork_spec_map.values())[0].subnetwork.logits
if isinstance(first_subnetwork_logits, dict):
for head_name in first_subnetwork_logits.keys():
subnetwork_logits = {
subnetwork_name: subnetwork_spec.subnetwork.logits[head_name]
for subnetwork_name, subnetwork_spec in
subnetwork_spec_map.items()
}
export_outputs.update({
"{}_{}".format(
_EnsembleBuilder._SUBNETWORK_LOGITS_EXPORT_SIGNATURE,
head_name):
tf.estimator.export.PredictOutput(subnetwork_logits)
})
else:
subnetwork_logits = {
subnetwork_name: subnetwork_spec.subnetwork.logits for
subnetwork_name, subnetwork_spec in subnetwork_spec_map.items()
}
export_outputs.update({
_EnsembleBuilder._SUBNETWORK_LOGITS_EXPORT_SIGNATURE:
tf.estimator.export.PredictOutput(subnetwork_logits)
})
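# Likewise, optionally export each subnetwork's last layer, which can
# serve as an embedding downstream.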
if (self._export_subnetwork_last_layer and export_outputs and
subnetwork_spec_map and
list(subnetwork_spec_map.values())[0].subnetwork.last_layer is
not None):
first_subnetwork_last_layer = list(
subnetwork_spec_map.values())[0].subnetwork.last_layer
if isinstance(first_subnetwork_last_layer, dict):
for head_name in first_subnetwork_last_layer.keys():
subnetwork_last_layer = {
subnetwork_name:
subnetwork_spec.subnetwork.last_layer[head_name] for
subnetwork_name, subnetwork_spec in subnetwork_spec_map.items()
}
export_outputs.update({
"{}_{}".format(
_EnsembleBuilder._SUBNETWORK_LAST_LAYER_EXPORT_SIGNATURE,
head_name):
tf.estimator.export.PredictOutput(subnetwork_last_layer)
})
else:
subnetwork_last_layer = {
subnetwork_name: subnetwork_spec.subnetwork.last_layer for
subnetwork_name, subnetwork_spec in subnetwork_spec_map.items()
}
export_outputs.update({
_EnsembleBuilder._SUBNETWORK_LAST_LAYER_EXPORT_SIGNATURE:
tf.estimator.export.PredictOutput(subnetwork_last_layer)
})
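# Merge any ensembler-provided predictions into the head's predictions
# and serving signatures.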
if ensemble.predictions and predictions:
predictions.update(ensemble.predictions)
if ensemble.predictions and export_outputs:
export_outputs.update({
k: tf.estimator.export.PredictOutput(v)
for k, v in ensemble.predictions.items()
})
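# Create eval metrics, populated only in EVAL mode.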
ensemble_metrics = _EnsembleMetrics(use_tpu=self._use_tpu)
if mode == tf.estimator.ModeKeys.EVAL:
ensemble_metrics.create_eval_metrics(
features=features,
labels=labels,
estimator_spec=estimator_spec,
metric_fn=self._metric_fn,
architecture=architecture)
if mode == tf.estimator.ModeKeys.TRAIN:
with summary.current_scope():
summary.scalar("loss", estimator_spec.loss)
ensemble_trainable_vars = _get_current_vars(
diffbase=old_vars)["trainable"]
# Create train ops for training subnetworks and ensembles.
train_op = None
if mode == tf.estimator.ModeKeys.TRAIN:
# Note that these mixture weights are on top of the last_layer of the
# subnetwork constructed in TRAIN mode, which means that dropout is
# still applied when the mixture weights are being trained.
ensemble_scope = tf_compat.v1.get_variable_scope()
with tf_compat.v1.variable_scope("train_mixture_weights"):
with summary.current_scope(), _monkey_patch_context(
iteration_step_scope=ensemble_scope,
scoped_summary=summary,
trainable_vars=ensemble_trainable_vars):
# For backwards compatibility with subnetwork Builders that define their
# own `build_mixture_weights_train_op`.
subnetwork_builder = candidate.subnetwork_builders[0]
old_train_op_fn = getattr(subnetwork_builder,
"build_mixture_weights_train_op", None)
if callable(old_train_op_fn):
logging.warning(
"The `build_mixture_weights_train_op` method is deprecated. "
"Please use `Ensembler#build_train_op` instead.")
train_op = _to_train_op_spec(
subnetwork_builder.build_mixture_weights_train_op(
loss=adanet_loss,
var_list=ensemble_trainable_vars,
logits=ensemble.logits,
labels=labels,
iteration_step=step_tensor,
summary=summary))
else:
train_op = _to_train_op_spec(
ensembler.build_train_op(
ensemble=ensemble,
loss=adanet_loss,
var_list=ensemble_trainable_vars,
labels=labels,
iteration_step=step_tensor,
summary=summary,
previous_ensemble=previous_ensemble))
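# Collect every variable created while building this ensemble so they can
# be saved and restored along with the spec.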
new_vars = _get_current_vars(diffbase=old_vars)
# Sort the dictionary by key to make the variable order deterministic.
new_vars = collections.OrderedDict(sorted(new_vars.items()))
# Combine all trainable, global and savable variables into a single list.
ensemble_variables = sum(new_vars.values(), []) + [step]
return _EnsembleSpec(
name=name,
architecture=architecture,
subnetwork_builders=subnetwork_builders,
subnetwork_specs=previous_subnetwork_specs + relevant_subnetwork_specs,
ensemble=ensemble,
predictions=predictions,
step=step,
variables=ensemble_variables,
loss=ensemble_loss,
adanet_loss=adanet_loss,
train_op=train_op,
eval_metrics=ensemble_metrics,
export_outputs=export_outputs)