# Excerpt from adanet/core/estimator.py
def _get_best_ensemble_index(self,
                             current_iteration,
                             input_hooks,
                             checkpoint_path=None):
  # type: (_Iteration, Sequence[tf_compat.SessionRunHook], Text) -> int
  """Returns the best candidate ensemble's index in this iteration.

  Evaluates the ensembles using an `Evaluator` when provided. Otherwise,
  it returns the index of the best candidate as defined by the `_Iteration`.

  Selection short-circuits, in priority order, before any evaluation runs:
  a replay config with a recorded winner, a single-candidate iteration, or
  `force_grow` at iteration t>0 with exactly two candidates (where index 0
  is always the previous ensemble, so index 1 must win).

  Args:
    current_iteration: Current `_Iteration`.
    input_hooks: List of SessionRunHooks to be included when running.
    checkpoint_path: Checkpoint to use when determining the best index.

  Returns:
    Index of the best ensemble in the iteration's list of `_Candidates`.
  """
  # AdaNet Replay: reuse a previously recorded winning index, if any.
  if self._replay_config:
    best_index = self._replay_config.get_best_ensemble_index(
        current_iteration.number)
    if best_index is not None:
      return best_index
  # Skip the evaluation phase when there is only one candidate subnetwork.
  if len(current_iteration.candidates) == 1:
    logging.info("'%s' is the only ensemble",
                 current_iteration.candidates[0].ensemble_spec.name)
    return 0
  # The zero-th index candidate at iteration t>0 is always the
  # previous_ensemble. With exactly two candidates and `force_grow` on,
  # the previous ensemble may not win, so index 1 is the only choice.
  if current_iteration.number > 0 and self._force_grow and (len(
      current_iteration.candidates) == 2):
    logging.info("With `force_grow` enabled, '%s' is the only ensemble",
                 current_iteration.candidates[1].ensemble_spec.name)
    return 1
  logging.info("Starting ensemble evaluation for iteration %s",
               current_iteration.number)
  # Hooks must see `begin()` before the session exists (SessionRunHook
  # contract), hence the call outside the `with` block below.
  for hook in input_hooks:
    hook.begin()
  with tf_compat.v1.Session(config=self.config.session_config) as sess:
    # Initialize all variables/tables first; the checkpoint restore below
    # then overwrites whatever state the checkpoint actually contains.
    # The scaffold's local_init_op is only present on a core
    # `tf.estimator.EstimatorSpec` (not on the TPU variant), hence the
    # isinstance guard with a no-op fallback.
    init = tf.group(
        tf_compat.v1.global_variables_initializer(),
        tf_compat.v1.local_variables_initializer(),
        tf_compat.v1.tables_initializer(),
        current_iteration.estimator_spec.scaffold.local_init_op if isinstance(
            current_iteration.estimator_spec,
            tf.estimator.EstimatorSpec) else tf.no_op())
    sess.run(init)
    if self._enable_v2_checkpoint:
      # TF2 object-based checkpoint: tolerate values in the checkpoint
      # that this graph does not use (expect_partial), then restore into
      # the active session.
      status = current_iteration.checkpoint.restore(checkpoint_path)
      status.expect_partial()  # Optional sanity checks.
      status.initialize_or_restore(sess)
    else:
      # TF1 name-based checkpoint restore; sharded to match how the
      # training checkpoints are written.
      saver = tf_compat.v1.train.Saver(sharded=True)
      saver.restore(sess, checkpoint_path)
    coord = tf.train.Coordinator()
    for hook in input_hooks:
      hook.after_create_session(sess, coord)
    # NOTE(review): the queue runners/coordinator started here are never
    # explicitly stopped or joined, and hooks never get `end()`; the
    # session context exit appears to be relied on for teardown — confirm.
    tf_compat.v1.train.start_queue_runners(sess=sess, coord=coord)
    # Build one metric dict per candidate; "adanet_loss" is always
    # included so the no-Evaluator fallback below has a common objective.
    ensemble_metrics = []
    for candidate in current_iteration.candidates:
      metrics = candidate.ensemble_spec.eval_metrics.eval_metrics_ops()
      metrics["adanet_loss"] = tf_compat.v1.metrics.mean(
          candidate.ensemble_spec.adanet_loss)
      ensemble_metrics.append(metrics)
    if self._evaluator:
      # Custom Evaluator decides both the metric and whether lower or
      # higher is better (via its objective_fn).
      metric_name = self._evaluator.metric_name
      metrics = self._evaluator.evaluate(sess, ensemble_metrics)
      objective_fn = self._evaluator.objective_fn
    else:
      # Default: minimize adanet_loss; nanargmin skips NaN candidates.
      metric_name = "adanet_loss"
      metrics = sess.run(
          [c.adanet_loss for c in current_iteration.candidates])
      objective_fn = np.nanargmin
    # Log every candidate's objective value for debugging/inspection.
    values = []
    for i in range(len(current_iteration.candidates)):
      ensemble_name = current_iteration.candidates[i].ensemble_spec.name
      values.append("{}/{} = {:.6f}".format(metric_name, ensemble_name,
                                            metrics[i]))
    logging.info("Computed ensemble metrics: %s", ", ".join(values))
    if self._force_grow and current_iteration.number > 0:
      logging.info(
          "The `force_grow` override is enabled, so the "
          "the performance of the previous ensemble will be ignored.")
      # NOTE: The zero-th index candidate at iteration t>0 is always the
      # previous_ensemble. Drop it from consideration, then shift the
      # argmin result back up by one to index the full candidate list.
      metrics = metrics[1:]
      index = objective_fn(metrics) + 1
    else:
      index = objective_fn(metrics)
    logging.info("Finished ensemble evaluation for iteration %s",
                 current_iteration.number)
    logging.info("'%s' at index %s is the best ensemble",
                 current_iteration.candidates[index].ensemble_spec.name, index)
    return index