in adanet/core/iteration.py [0:0]
def build_iteration(self,
base_global_step,
iteration_number,
ensemble_candidates,
subnetwork_builders,
features,
mode,
config,
labels=None,
previous_ensemble_summary=None,
rebuilding=False,
rebuilding_ensembler_name=None,
best_ensemble_index_override=None,
previous_iteration=None):
"""Builds and returns AdaNet iteration t.
This method uses the generated the candidate subnetworks given the ensemble
at iteration t-1 and creates graph operations to train them. The returned
`_Iteration` tracks the training of all candidates to know when the
iteration is over, and tracks the best candidate's predictions and loss, as
defined by lowest complexity-regularized loss on the train set.
Args:
base_global_step: Integer global step at the beginning of this iteration.
iteration_number: Integer iteration number.
ensemble_candidates: Iterable of `adanet.ensemble.Candidate` instances.
subnetwork_builders: A list of `Builders` for adding ` Subnetworks` to the
graph. Each subnetwork is then wrapped in a `_Candidate` to train.
features: Dictionary of `Tensor` objects keyed by feature name.
mode: Defines whether this is training, evaluation or prediction. See
`ModeKeys`.
config: The `tf.estimator.RunConfig` to use this iteration.
labels: `Tensor` of labels. Can be `None`.
previous_ensemble_summary: The `adanet.Summary` for the previous ensemble.
rebuilding: Boolean whether the iteration is being rebuilt only to restore
the previous best subnetworks and ensembles.
rebuilding_ensembler_name: Optional ensembler to restrict to, only
relevant when rebuilding is set as True.
best_ensemble_index_override: Integer index to identify the best ensemble
candidate instead of computing the best ensemble index dynamically
conditional on the ensemble AdaNet losses.
previous_iteration: The iteration occuring before this one or None if this
is the first iteration.
Returns:
An _Iteration instance.
Raises:
ValueError: If subnetwork_builders is empty.
ValueError: If two subnetworks share the same name.
ValueError: If two ensembles share the same name.
"""
self._placement_strategy.config = config
logging.info("%s iteration %s", "Rebuilding" if rebuilding else "Building",
iteration_number)
if not subnetwork_builders:
raise ValueError("Each iteration must have at least one Builder.")
# TODO: Consider moving builder mode logic to ensemble_builder.py.
builder_mode = mode
if rebuilding:
# Build the subnetworks and ensembles in EVAL mode by default. This way
# their outputs aren't affected by dropout etc.
builder_mode = tf.estimator.ModeKeys.EVAL
if mode == tf.estimator.ModeKeys.PREDICT:
builder_mode = mode
# Only replicate in training mode when the user requests it.
if self._replicate_ensemble_in_training and (
mode == tf.estimator.ModeKeys.TRAIN):
builder_mode = mode
features, labels = self._check_numerics(features, labels)
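    # Maps each candidate's index to its architecture's replay indices; these
    # are later passed to the iteration's metrics for ensemble replay.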
replay_indices_for_all = {}
training = mode == tf.estimator.ModeKeys.TRAIN
skip_summaries = mode == tf.estimator.ModeKeys.PREDICT or rebuilding
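    # All ops and variables for iteration t are scoped under "iteration_t".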
with tf_compat.v1.variable_scope("iteration_{}".format(iteration_number)):
seen_builder_names = {}
candidates = []
summaries = []
subnetwork_reports = {}
previous_ensemble = None
previous_ensemble_spec = None
previous_iteration_checkpoint = None
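      # When a previous iteration exists, restore its checkpoint and carry its
      # best ensemble forward as this iteration's first candidate.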
if previous_iteration:
previous_iteration_checkpoint = previous_iteration.checkpoint
previous_best_candidate = previous_iteration.candidates[-1]
previous_ensemble_spec = previous_best_candidate.ensemble_spec
previous_ensemble = previous_ensemble_spec.ensemble
replay_indices_for_all[len(candidates)] = copy.copy(
previous_ensemble_spec.architecture.replay_indices)
        # Include the previous iteration's best candidate so that its
        # predictions are returned until a new candidate outperforms it.
seen_builder_names = {previous_ensemble_spec.name: True}
candidates.append(previous_best_candidate)
if self._enable_ensemble_summaries:
summaries.append(previous_ensemble_summary)
# Generate subnetwork reports.
if (self._enable_subnetwork_reports and
mode == tf.estimator.ModeKeys.EVAL):
metrics = previous_ensemble_spec.eval_metrics.eval_metrics_ops()
subnetwork_report = subnetwork.Report(
hparams={},
attributes={},
metrics=metrics,
)
subnetwork_report.metrics["adanet_loss"] = tf_compat.v1.metrics.mean(
previous_ensemble_spec.adanet_loss)
subnetwork_reports["previous_ensemble"] = subnetwork_report
for subnetwork_builder in subnetwork_builders:
if subnetwork_builder.name in seen_builder_names:
raise ValueError("Two subnetworks have the same name '{}'".format(
subnetwork_builder.name))
seen_builder_names[subnetwork_builder.name] = True
subnetwork_specs = []
num_subnetworks = len(subnetwork_builders)
skip_summary = skip_summaries or not self._enable_subnetwork_summaries
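      # Build a spec for each subnetwork builder. In distributed training, the
      # placement strategy may assign only a subset of builders to this
      # worker.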
for i, subnetwork_builder in enumerate(subnetwork_builders):
if not self._placement_strategy.should_build_subnetwork(
num_subnetworks, i) and not rebuilding:
continue
with self._placement_strategy.subnetwork_devices(num_subnetworks, i):
subnetwork_name = "t{}_{}".format(iteration_number,
subnetwork_builder.name)
subnetwork_summary = self._summary_maker(
namespace="subnetwork",
scope=subnetwork_name,
skip_summary=skip_summary)
if not skip_summary:
summaries.append(subnetwork_summary)
logging.info("%s subnetwork '%s'",
"Rebuilding" if rebuilding else "Building",
subnetwork_builder.name)
subnetwork_spec = self._subnetwork_manager.build_subnetwork_spec(
name=subnetwork_name,
subnetwork_builder=subnetwork_builder,
summary=subnetwork_summary,
features=features,
mode=builder_mode,
labels=labels,
previous_ensemble=previous_ensemble,
config=config)
subnetwork_specs.append(subnetwork_spec)
# Workers that don't build ensembles need a dummy candidate in order
# to train the subnetwork.
# Because only ensembles can be considered candidates, we need to
# convert the subnetwork into a dummy ensemble and subsequently a
# dummy candidate. However, this dummy candidate is never considered a
# true candidate during candidate evaluation and selection.
# TODO: Eliminate need for candidates.
if not self._placement_strategy.should_build_ensemble(
num_subnetworks) and not rebuilding:
candidates.append(
self._create_dummy_candidate(subnetwork_spec,
subnetwork_builders,
subnetwork_summary, training))
# Generate subnetwork reports.
if (self._enable_subnetwork_reports and
mode != tf.estimator.ModeKeys.PREDICT):
subnetwork_report = subnetwork_builder.build_subnetwork_report()
if not subnetwork_report:
subnetwork_report = subnetwork.Report(
hparams={}, attributes={}, metrics={})
metrics = subnetwork_spec.eval_metrics.eval_metrics_ops()
for metric_name in sorted(metrics):
metric = metrics[metric_name]
subnetwork_report.metrics[metric_name] = metric
subnetwork_reports[subnetwork_builder.name] = subnetwork_report
      # Create one ensemble per (ensemble_candidate, ensembler) pair.
skip_summary = skip_summaries or not self._enable_ensemble_summaries
seen_ensemble_names = {}
for ensembler in self._ensemblers:
if rebuilding and rebuilding_ensembler_name and (
ensembler.name != rebuilding_ensembler_name):
continue
for ensemble_candidate in ensemble_candidates:
if not self._placement_strategy.should_build_ensemble(
num_subnetworks) and not rebuilding:
continue
ensemble_name = "t{}_{}_{}".format(iteration_number,
ensemble_candidate.name,
ensembler.name)
if ensemble_name in seen_ensemble_names:
raise ValueError(
"Two ensembles have the same name '{}'".format(ensemble_name))
seen_ensemble_names[ensemble_name] = True
summary = self._summary_maker(
namespace="ensemble",
scope=ensemble_name,
skip_summary=skip_summary)
if not skip_summary:
summaries.append(summary)
ensemble_spec = self._ensemble_builder.build_ensemble_spec(
name=ensemble_name,
candidate=ensemble_candidate,
ensembler=ensembler,
subnetwork_specs=subnetwork_specs,
summary=summary,
features=features,
mode=builder_mode,
iteration_number=iteration_number,
labels=labels,
my_ensemble_index=len(candidates),
previous_ensemble_spec=previous_ensemble_spec,
previous_iteration_checkpoint=previous_iteration_checkpoint)
# TODO: Eliminate need for candidates.
candidate = self._candidate_builder.build_candidate(
ensemble_spec=ensemble_spec,
training=training,
summary=summary,
rebuilding=rebuilding)
replay_indices_for_all[len(candidates)] = copy.copy(
ensemble_spec.architecture.replay_indices)
candidates.append(candidate)
          # TODO: Move adanet_loss from the subnetwork report to a new
          # ensemble report, since adanet_loss is associated with an
          # ensemble, and only when using a `ComplexityRegularizedEnsembler`.
          # Keep adanet_loss in the subnetwork report for backwards
          # compatibility.
if len(ensemble_candidates) != len(subnetwork_builders):
continue
if len(ensemble_candidate.subnetwork_builders) > 1:
continue
if mode == tf.estimator.ModeKeys.PREDICT:
continue
builder_name = ensemble_candidate.subnetwork_builders[0].name
if self._enable_subnetwork_reports:
subnetwork_reports[builder_name].metrics[
"adanet_loss"] = tf_compat.v1.metrics.mean(
ensemble_spec.adanet_loss)
# Dynamically select the outputs of best candidate.
best_candidate_index = self._best_candidate_index(
candidates, best_ensemble_index_override)
best_predictions = self._best_predictions(candidates,
best_candidate_index)
best_loss = self._best_loss(candidates, best_candidate_index, mode)
best_export_outputs = self._best_export_outputs(candidates,
best_candidate_index,
mode, best_predictions)
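      # Training is driven by hooks; the train manager persists per-iteration
      # training state under its own "t<iteration>" directory.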
train_manager_dir = os.path.join(config.model_dir, "train_manager",
"t{}".format(iteration_number))
train_manager, training_chief_hooks, training_hooks = self._create_hooks(
base_global_step, subnetwork_specs, candidates, num_subnetworks,
rebuilding, train_manager_dir, config.is_chief)
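      # Collect local init ops from the previous ensemble's subnetworks and
      # the newly built subnetworks so the scaffold can run them at session
      # creation.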
local_init_ops = []
if previous_ensemble_spec:
for s in previous_ensemble_spec.ensemble.subnetworks:
if s.local_init_ops:
local_init_ops.extend(s.local_init_ops)
for subnetwork_spec in subnetwork_specs:
if (subnetwork_spec and subnetwork_spec.subnetwork and
subnetwork_spec.subnetwork.local_init_ops):
local_init_ops.extend(subnetwork_spec.subnetwork.local_init_ops)
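      # Iteration-level summary: record the iteration number and, if defined,
      # the best candidate's loss.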
summary = self._summary_maker(
namespace=None, scope=None, skip_summary=skip_summaries)
summaries.append(summary)
with summary.current_scope():
summary.scalar("iteration/adanet/iteration", iteration_number)
if best_loss is not None:
summary.scalar("loss", best_loss)
iteration_metrics = _IterationMetrics(iteration_number, candidates,
subnetwork_specs, self._use_tpu,
replay_indices_for_all)
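      # Checkpoint this iteration's variables so a later iteration (or a
      # rebuild) can restore them.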
checkpoint = self._make_checkpoint(candidates, subnetwork_specs,
iteration_number, previous_iteration)
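      # TPU requires a TPUEstimatorSpec, whose eval metrics are passed as a
      # (metric_fn, tensors) tuple rather than as eval_metric_ops.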
if self._use_tpu:
estimator_spec = tf_compat.v1.estimator.tpu.TPUEstimatorSpec(
mode=mode,
predictions=best_predictions,
loss=best_loss,
train_op=self._create_tpu_train_op(base_global_step,
subnetwork_specs, candidates,
mode, num_subnetworks, config),
eval_metrics=iteration_metrics.best_eval_metrics_tuple(
best_candidate_index, mode),
export_outputs=best_export_outputs,
training_hooks=training_hooks,
scaffold_fn=self._get_scaffold_fn(local_init_ops))
else:
estimator_spec = tf.estimator.EstimatorSpec(
mode=mode,
predictions=best_predictions,
loss=best_loss,
# All training happens in hooks so we don't need a train op.
train_op=tf.no_op() if training else None,
eval_metric_ops=iteration_metrics.best_eval_metric_ops(
best_candidate_index, mode),
export_outputs=best_export_outputs,
training_chief_hooks=training_chief_hooks,
training_hooks=training_hooks,
scaffold=self._get_scaffold_fn(local_init_ops)())
return _Iteration(
number=iteration_number,
candidates=candidates,
subnetwork_specs=subnetwork_specs,
estimator_spec=estimator_spec,
best_candidate_index=best_candidate_index,
summaries=summaries,
train_manager=train_manager,
subnetwork_reports=subnetwork_reports,
checkpoint=checkpoint,
previous_iteration=previous_iteration)