in adanet/core/estimator.py [0:0]
def _create_iteration(self,
                      features,
                      labels,
                      mode,
                      config,
                      is_growing_phase,
                      checkpoint_path,
                      hooks,
                      best_ensemble_index_override=None):
  """Builds the `_Iteration` holding the TF ops and variables to run now.

  Every architecture file found for an earlier iteration is replayed in
  order to rebuild the previous ensemble; the new candidates for the current
  iteration are then generated and built on top of it, all inside the
  "adanet" variable scope.

  Args:
    features: Dictionary of `Tensor` objects keyed by feature name.
    labels: Labels `Tensor` or a dictionary of string label name to `Tensor`
      (for multi-head). Can be `None`.
    mode: Defines whether this is training, evaluation or prediction. See
      `ModeKeys`.
    config: The current `tf.estimator.RunConfig`.
    is_growing_phase: Whether we are in the AdaNet graph growing phase. When
      set, `mode` must be TRAIN and `config.is_chief` must be true, and the
      iteration read from the checkpoint is advanced by one.
    checkpoint_path: Path of the checkpoint to use. When `None`, this method
      uses the latest checkpoint instead.
    hooks: A list of `tf.estimator.SessionRunHooks`.
    best_ensemble_index_override: Integer index to identify the latest
      iteration's best ensemble candidate instead of computing the best
      ensemble index dynamically conditional on the ensemble AdaNet losses.

  Returns:
    A two-tuple `(iteration, previous_iteration_vars)`: the current
    `_Iteration`, and the list of variables that existed before the new
    iteration was built (for restoring during the graph growing phase). The
    second element is `None` outside of the growing phase.
  """
  mode_keys = tf.estimator.ModeKeys

  # The iteration number is derived from the very checkpoint whose variable
  # values are used, so that two separate checkpoint reads cannot race.
  iteration_number = self._checkpoint_iteration_number(checkpoint_path)

  if mode == mode_keys.EVAL and checkpoint_path is None:
    # Expected only in some tests, hence a warning rather than an assertion.
    logging.warning("There are no checkpoints available during evaluation. "
                    "Variables will be initialized to their defaults.")

  if is_growing_phase:
    assert mode == mode_keys.TRAIN
    assert config.is_chief
    iteration_number += 1

  # Summaries are recorded only while training, and not while growing.
  skip_summaries = (mode != mode_keys.TRAIN or is_growing_phase)

  base_global_step = 0
  with tf_compat.v1.variable_scope("adanet"):
    # State carried from one replayed iteration to the next; all of it stays
    # `None` when no architecture file exists for any earlier iteration.
    rebuilt_iteration = None
    frozen_spec = None
    frozen_ensemble = None
    frozen_summary = None
    frozen_builders = None

    for past_number in range(iteration_number):
      architecture_path = self._architecture_filename(past_number)
      if not tf.io.gfile.exists(architecture_path):
        continue

      architecture = self._read_architecture(architecture_path)
      subnetwork_labels = sorted(
          "'{}:{}'".format(t, n)
          for t, n in architecture.subnetworks_grouped_by_iteration)
      logging.info("Importing architecture from %s: [%s].",
                   architecture_path, ", ".join(subnetwork_labels))

      base_global_step = architecture.global_step
      rebuilt_iteration = self._architecture_ensemble_spec(
          architecture, past_number, features, mode, labels, frozen_spec,
          config, rebuilt_iteration, hooks)

      # The last candidate of the rebuilt iteration is taken as the ensemble
      # that the next iteration extends.
      frozen_spec = rebuilt_iteration.candidates[-1].ensemble_spec
      frozen_ensemble = frozen_spec.ensemble
      frozen_summary = self._summary_maker(
          namespace="ensemble",
          scope=frozen_spec.name,
          skip_summary=skip_summaries)
      frozen_builders = frozen_spec.subnetwork_builders

    # Snapshot the variables that exist right now, before any new ones are
    # created, so they can be restored from the previous checkpoint after the
    # graph has grown. Variables created past this point have no counterpart
    # in that checkpoint until it gets overwritten.
    # Note: a tf.train.Saver cannot simply be created here because this code
    # also runs on TPU, which does not support creating Savers in model_fn.
    previous_iteration_vars = None
    if is_growing_phase:
      global_variables = tf_compat.v1.get_collection(
          tf_compat.v1.GraphKeys.GLOBAL_VARIABLES)
      saveable_objects = tf_compat.v1.get_collection(
          tf_compat.v1.GraphKeys.SAVEABLE_OBJECTS)
      previous_iteration_vars = global_variables + saveable_objects

    if self._report_materializer:
      previous_ensemble_reports, all_reports = (
          self._collate_subnetwork_reports(iteration_number))
    else:
      previous_ensemble_reports, all_reports = [], []

    new_builders = self._call_generate_candidates(
        previous_ensemble=frozen_ensemble,
        iteration_number=iteration_number,
        previous_ensemble_reports=previous_ensemble_reports,
        all_reports=all_reports,
        config=config)

    # Each ensemble strategy contributes its own candidates, in order.
    candidates = []
    for strategy in self._ensemble_strategies:
      candidates.extend(
          strategy.generate_ensemble_candidates(new_builders,
                                                frozen_builders))

    current_iteration = self._iteration_builder.build_iteration(
        base_global_step=base_global_step,
        iteration_number=iteration_number,
        ensemble_candidates=candidates,
        subnetwork_builders=new_builders,
        features=features,
        labels=labels,
        mode=mode,
        config=config,
        previous_ensemble_summary=frozen_summary,
        best_ensemble_index_override=best_ensemble_index_override,
        previous_iteration=rebuilt_iteration)

  return current_iteration, previous_iteration_vars