def _architecture_ensemble_spec()

in adanet/core/estimator.py [0:0]


  def _architecture_ensemble_spec(self, architecture, iteration_number,
                                  features, mode, labels,
                                  previous_ensemble_spec, config,
                                  previous_iteration, hooks):
    """Rebuilds iteration `iteration_number` of the given architecture.

    Creates the ensemble architecture by calling `generate_subnetworks` on
    `self._subnetwork_generator` and only calling `build_subnetwork` on
    `Builders` included in the architecture. Once their ops are created, their
    variables are restored from the checkpoint.

    Only the subnetwork group that `architecture` records for
    `iteration_number` is rebuilt; groups for other iterations are skipped.
    Afterwards, the replay indices of `architecture` are set on the
    architecture of the resulting (or, if nothing was rebuilt, the given
    previous) ensemble spec.

    Args:
      architecture: An `_Architecture` instance.
      iteration_number: Integer current iteration number.
      features: Dictionary of `Tensor` objects keyed by feature name.
      mode: Defines whether this is training, evaluation or prediction. See
        `ModeKeys`.
      labels: Labels `Tensor` or a dictionary of string label name to `Tensor`
        (for multi-head). Can be `None`.
      previous_ensemble_spec: The `_EnsembleSpec` for the previous iteration.
        Will be `None` for the first iteration.
      config: The current `tf.estimator.RunConfig`.
      previous_iteration: The previous `_Iteration`.
      hooks: A list of `tf.estimator.SessionRunHook`s. Not used by this method.

    Returns:
      The `_Iteration` returned by `self._iteration_builder.build_iteration`
      for the rebuilt ensemble, or `previous_iteration` unchanged when
      `architecture` lists no subnetworks for `iteration_number`.

    Raises:
      ValueError: If a subnetwork from `architecture` is not found in the
        generated candidate `Builders` of the specified iteration, or if
        `architecture` lists no subnetworks for `iteration_number` while
        `previous_ensemble_spec` is `None` (nothing to rebuild).
    """
    del hooks  # Unused.

    previous_ensemble = (
        previous_ensemble_spec.ensemble if previous_ensemble_spec else None)
    current_iteration = previous_iteration
    for t, names in architecture.subnetworks_grouped_by_iteration:
      if t != iteration_number:
        continue
      previous_ensemble_reports, all_reports = [], []
      if self._report_materializer:
        previous_ensemble_reports, all_reports = (
            self._collate_subnetwork_reports(iteration_number))
      generated_subnetwork_builders = self._call_generate_candidates(
          previous_ensemble=previous_ensemble,
          iteration_number=iteration_number,
          previous_ensemble_reports=previous_ensemble_reports,
          all_reports=all_reports,
          config=config)
      builders_by_name = {b.name: b for b in generated_subnetwork_builders}
      # Keep only the builders the architecture recorded, in its order.
      rebuild_subnetwork_builders = []
      for name in names:
        if name not in builders_by_name:
          raise ValueError(
              "Required subnetwork builder is missing for iteration {}: {}"
              .format(iteration_number, name))
        rebuild_subnetwork_builders.append(builders_by_name[name])
      previous_ensemble_summary = None
      previous_ensemble_subnetwork_builders = None
      if previous_ensemble_spec:
        # Always skip summaries when rebuilding previous architecture,
        # since they are not useful.
        previous_ensemble_summary = self._summary_maker(
            namespace="ensemble",
            scope=previous_ensemble_spec.name,
            skip_summary=True)
        previous_ensemble_subnetwork_builders = (
            previous_ensemble_spec.subnetwork_builders)
      ensemble_candidates = []
      for ensemble_strategy in self._ensemble_strategies:
        ensemble_candidates += ensemble_strategy.generate_ensemble_candidates(
            rebuild_subnetwork_builders, previous_ensemble_subnetwork_builders)
      ensemble_candidate = self._find_ensemble_candidate(
          architecture.ensemble_candidate_name, ensemble_candidates)
      current_iteration = self._iteration_builder.build_iteration(
          base_global_step=architecture.global_step,
          iteration_number=iteration_number,
          ensemble_candidates=[ensemble_candidate],
          subnetwork_builders=rebuild_subnetwork_builders,
          features=features,
          labels=labels,
          mode=mode,
          config=config,
          previous_ensemble_summary=previous_ensemble_summary,
          rebuilding=True,
          rebuilding_ensembler_name=architecture.ensembler_name,
          previous_iteration=current_iteration)
      # NOTE(review): presumably `build_iteration` also yields the previous
      # ensemble as a candidate when one exists; either way the rebuilt
      # ensemble is taken from the last position.
      expected_candidates = 2 if previous_ensemble_spec else 1
      num_candidates = len(current_iteration.candidates)
      assert num_candidates == expected_candidates, (
          "Expected {} candidate(s) when rebuilding iteration {}, got {}."
          .format(expected_candidates, iteration_number, num_candidates))
      previous_ensemble_spec = current_iteration.candidates[-1].ensemble_spec
      previous_ensemble = previous_ensemble_spec.ensemble
    if previous_ensemble_spec is None:
      # Without this check the call below fails with an opaque AttributeError
      # on `None`.
      raise ValueError(
          "Architecture has no subnetworks for iteration {} and there is no "
          "previous ensemble to rebuild.".format(iteration_number))
    previous_ensemble_spec.architecture.set_replay_indices(
        architecture.replay_indices)
    return current_iteration