def build_ensemble_spec()

in adanet/core/ensemble_builder.py


  def build_ensemble_spec(self,
                          name,
                          candidate,
                          ensembler,
                          subnetwork_specs,
                          summary,
                          features,
                          mode,
                          iteration_number,
                          labels,
                          my_ensemble_index,
                          previous_ensemble_spec,
                          previous_iteration_checkpoint):
    """Builds an `_EnsembleSpec` with the given `adanet.ensemble.Candidate`.

    Args:
      name: The string name of the ensemble. Typically derived from the name
        of the builder that produced the candidate's subnetworks.
      candidate: The `adanet.ensemble.Candidate` for this spec.
      ensembler: The :class:`adanet.ensemble.Ensembler` to use to ensemble a
        group of subnetworks.
      subnetwork_specs: Iterable of `_SubnetworkSpec`s for this iteration.
      summary: A `_ScopedSummary` instance for recording ensemble summaries.
      features: Input `dict` of `Tensor` objects.
      mode: Estimator `ModeKeys` indicating training, evaluation, or inference.
      iteration_number: Integer current iteration number.
      labels: Labels `Tensor` or a dictionary of string label name to `Tensor`
        (for multi-head).
      my_ensemble_index: An integer holding the index of the ensemble in the
        candidates list of AdaNet.
      previous_ensemble_spec: The `_EnsembleSpec` from iteration t-1, linked
        into this spec. Used for creating the subnetwork train_op.
      previous_iteration_checkpoint: `tf.train.Checkpoint` for iteration t-1.

    Returns:
      An `_EnsembleSpec` instance.
    """

    with tf_compat.v1.variable_scope("ensemble_{}".format(name)):
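      # A non-trainable, per-ensemble counter of training steps taken within
      # this AdaNet iteration; exposed to train ops below as `iteration_step`.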
      step = tf_compat.v1.get_variable(
          "step",
          shape=[],
          initializer=tf_compat.v1.zeros_initializer(),
          trainable=False,
          dtype=tf.int64)
      # Convert to tensor so that users cannot mutate it.
      step_tensor = tf.convert_to_tensor(value=step)
      with summary.current_scope():
        summary.scalar("iteration_step/adanet/iteration_step", step_tensor)
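      # Carry over the indices of previously chosen candidates, then append
      # this candidate's own index, so the ensemble can be replayed later.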
      replay_indices = []
      if previous_ensemble_spec:
        replay_indices = copy.copy(
            previous_ensemble_spec.architecture.replay_indices)
      if my_ensemble_index is not None:
        replay_indices.append(my_ensemble_index)

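      # Record which subnetworks (by iteration number and builder name) make up
      # this ensemble.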
      architecture = _Architecture(
          candidate.name, ensembler.name, replay_indices=replay_indices)
      previous_subnetworks = []
      previous_subnetwork_specs = []
      subnetwork_builders = []
      previous_ensemble = None
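      # When extending a previous ensemble, carry over its subnetworks, specs,
      # and builders, subject to the optional pruning below.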
      if previous_ensemble_spec:
        previous_ensemble = previous_ensemble_spec.ensemble
        previous_architecture = previous_ensemble_spec.architecture
        keep_indices = range(len(previous_ensemble.subnetworks))
        if len(candidate.subnetwork_builders) == 1 and previous_ensemble:
          # Prune previous ensemble according to the subnetwork.Builder for
          # backwards compatibility.
          subnetwork_builder = candidate.subnetwork_builders[0]
          prune_previous_ensemble = getattr(subnetwork_builder,
                                            "prune_previous_ensemble", None)
          if callable(prune_previous_ensemble):
            logging.warning(
                "Using an `adanet.subnetwork.Builder#prune_previous_ensemble` "
                "is deprecated. Please use a custom `adanet.ensemble.Strategy` "
                "instead.")
            keep_indices = prune_previous_ensemble(previous_ensemble)
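        # Keep only previous subnetworks that survived pruning and that this
        # candidate proposes to reuse.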
        for i, builder in enumerate(previous_ensemble_spec.subnetwork_builders):
          if i not in keep_indices:
            continue
          if builder not in candidate.previous_ensemble_subnetwork_builders:
            continue
          previous_subnetworks.append(previous_ensemble.subnetworks[i])
          previous_subnetwork_specs.append(
              previous_ensemble_spec.subnetwork_specs[i])
          subnetwork_builders.append(builder)
          architecture.add_subnetwork(*previous_architecture.subnetworks[i])
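      # Register the candidate's new subnetworks for this iteration.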
      for builder in candidate.subnetwork_builders:
        architecture.add_subnetwork(iteration_number, builder.name)
        subnetwork_builders.append(builder)
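      # Index this iteration's subnetwork specs by builder name for lookup.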
      subnetwork_spec_map = {s.builder.name: s for s in subnetwork_specs}
      relevant_subnetwork_specs = [
          subnetwork_spec_map[s.name] for s in candidate.subnetwork_builders
      ]
      ensemble_scope = tf_compat.v1.get_variable_scope()

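      # Snapshot pre-existing variables so that everything created below can be
      # attributed to this ensemble via a diff against this baseline.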
      old_vars = _get_current_vars()

      with summary.current_scope(), _monkey_patch_context(
          iteration_step_scope=ensemble_scope,
          scoped_summary=summary,
          trainable_vars=[]):
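        # Delegate combining the subnetworks (e.g. learning mixture weights)
        # to the Ensembler.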
        ensemble = ensembler.build_ensemble(
            subnetworks=[s.subnetwork for s in relevant_subnetwork_specs],
            previous_ensemble_subnetworks=previous_subnetworks,
            features=features,
            labels=labels,
            logits_dimension=self._head.logits_dimension,
            training=mode == tf.estimator.ModeKeys.TRAIN,
            iteration_step=step_tensor,
            summary=summary,
            previous_ensemble=previous_ensemble,
            previous_iteration_checkpoint=previous_iteration_checkpoint)

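      # Let the head derive mode-specific loss, predictions, and export
      # outputs from the ensemble's logits.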
      estimator_spec = _create_estimator_spec(self._head, features, labels,
                                              mode, ensemble.logits,
                                              self._use_tpu)

      ensemble_loss = estimator_spec.loss
      adanet_loss = None
      if mode != tf.estimator.ModeKeys.PREDICT:
        adanet_loss = estimator_spec.loss
        # Add the ensembler-specific penalty, e.g. complexity regularization.
        if isinstance(ensemble, ensemble_lib.ComplexityRegularized):
          adanet_loss += ensemble.complexity_regularization

      predictions = estimator_spec.predictions
      export_outputs = estimator_spec.export_outputs

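      # Optionally export each subnetwork's logits as extra serving signatures;
      # multi-head subnetworks (dict logits) get one signature per head.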
      if (self._export_subnetwork_logits and
          export_outputs and subnetwork_spec_map):
        first_subnetwork_logits = list(
            subnetwork_spec_map.values())[0].subnetwork.logits
        if isinstance(first_subnetwork_logits, dict):
          for head_name in first_subnetwork_logits.keys():
            subnetwork_logits = {
                subnetwork_name: subnetwork_spec.subnetwork.logits[head_name]
                for subnetwork_name, subnetwork_spec in
                subnetwork_spec_map.items()
            }
            export_outputs.update({
                "{}_{}".format(
                    _EnsembleBuilder._SUBNETWORK_LOGITS_EXPORT_SIGNATURE,
                    head_name):
                    tf.estimator.export.PredictOutput(subnetwork_logits)
            })
        else:
          subnetwork_logits = {
              subnetwork_name: subnetwork_spec.subnetwork.logits for
              subnetwork_name, subnetwork_spec in subnetwork_spec_map.items()
          }
          export_outputs.update({
              _EnsembleBuilder._SUBNETWORK_LOGITS_EXPORT_SIGNATURE:
                  tf.estimator.export.PredictOutput(subnetwork_logits)
          })

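      # Likewise, optionally export each subnetwork's last layer when present.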
      if (self._export_subnetwork_last_layer and export_outputs and
          subnetwork_spec_map and
          list(subnetwork_spec_map.values())[0].subnetwork.last_layer is
          not None):
        first_subnetwork_last_layer = list(
            subnetwork_spec_map.values())[0].subnetwork.last_layer
        if isinstance(first_subnetwork_last_layer, dict):
          for head_name in first_subnetwork_last_layer.keys():
            subnetwork_last_layer = {
                subnetwork_name:
                subnetwork_spec.subnetwork.last_layer[head_name] for
                subnetwork_name, subnetwork_spec in subnetwork_spec_map.items()
            }
            export_outputs.update({
                "{}_{}".format(
                    _EnsembleBuilder._SUBNETWORK_LAST_LAYER_EXPORT_SIGNATURE,
                    head_name):
                    tf.estimator.export.PredictOutput(subnetwork_last_layer)
            })
        else:
          subnetwork_last_layer = {
              subnetwork_name: subnetwork_spec.subnetwork.last_layer for
              subnetwork_name, subnetwork_spec in subnetwork_spec_map.items()
          }
          export_outputs.update({
              _EnsembleBuilder._SUBNETWORK_LAST_LAYER_EXPORT_SIGNATURE:
                  tf.estimator.export.PredictOutput(subnetwork_last_layer)
          })

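      # Merge any ensembler-provided predictions into the head's predictions
      # and serving signatures.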
      if ensemble.predictions and predictions:
        predictions.update(ensemble.predictions)
      if ensemble.predictions and export_outputs:
        export_outputs.update({
            k: tf.estimator.export.PredictOutput(v)
            for k, v in ensemble.predictions.items()
        })

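      # Eval metrics are only populated in EVAL mode.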
      ensemble_metrics = _EnsembleMetrics(use_tpu=self._use_tpu)
      if mode == tf.estimator.ModeKeys.EVAL:
        ensemble_metrics.create_eval_metrics(
            features=features,
            labels=labels,
            estimator_spec=estimator_spec,
            metric_fn=self._metric_fn,
            architecture=architecture)

      if mode == tf.estimator.ModeKeys.TRAIN:
        with summary.current_scope():
          summary.scalar("loss", estimator_spec.loss)

      ensemble_trainable_vars = _get_current_vars(
          diffbase=old_vars)["trainable"]
      # Create train ops for training subnetworks and ensembles.
      train_op = None
      if mode == tf.estimator.ModeKeys.TRAIN:
        # Note that these mixture weights are on top of the last_layer of the
        # subnetwork constructed in TRAIN mode, which means that dropout is
        # still applied when the mixture weights are being trained.
        ensemble_scope = tf_compat.v1.get_variable_scope()
        with tf_compat.v1.variable_scope("train_mixture_weights"):
          with summary.current_scope(), _monkey_patch_context(
              iteration_step_scope=ensemble_scope,
              scoped_summary=summary,
              trainable_vars=ensemble_trainable_vars):
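            # Only the variables created while building this ensemble (e.g.
            # mixture weights) are trained here; each subnetwork is trained by
            # its own train op.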
            # For backwards compatibility.
            subnetwork_builder = candidate.subnetwork_builders[0]
            old_train_op_fn = getattr(subnetwork_builder,
                                      "build_mixture_weights_train_op", None)
            if callable(old_train_op_fn):
              logging.warning(
                  "The `build_mixture_weights_train_op` method is deprecated. "
                  "Please use `Ensembler#build_train_op` instead.")
              train_op = _to_train_op_spec(
                  subnetwork_builder.build_mixture_weights_train_op(
                      loss=adanet_loss,
                      var_list=ensemble_trainable_vars,
                      logits=ensemble.logits,
                      labels=labels,
                      iteration_step=step_tensor,
                      summary=summary))
            else:
              train_op = _to_train_op_spec(
                  ensembler.build_train_op(
                      ensemble=ensemble,
                      loss=adanet_loss,
                      var_list=ensemble_trainable_vars,
                      labels=labels,
                      iteration_step=step_tensor,
                      summary=summary,
                      previous_ensemble=previous_ensemble))

      new_vars = _get_current_vars(diffbase=old_vars)
      # Sort our dictionary by key to remove non-determinism of variable order.
      new_vars = collections.OrderedDict(sorted(new_vars.items()))
      # Combine all trainable, global and savable variables into a single list.
      ensemble_variables = sum(new_vars.values(), []) + [step]

    return _EnsembleSpec(
        name=name,
        architecture=architecture,
        subnetwork_builders=subnetwork_builders,
        subnetwork_specs=previous_subnetwork_specs + relevant_subnetwork_specs,
        ensemble=ensemble,
        predictions=predictions,
        step=step,
        variables=ensemble_variables,
        loss=ensemble_loss,
        adanet_loss=adanet_loss,
        train_op=train_op,
        eval_metrics=ensemble_metrics,
        export_outputs=export_outputs)
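
Usage sketch (not from the AdaNet sources): `build_ensemble_spec` is an
internal method, so every name below (`ensemble_builder`, `candidate`,
`specs`, `summary`, `features`, `labels`) is a hypothetical placeholder for
objects that the surrounding iteration machinery normally supplies. The
snippet only illustrates how the arguments fit together on a first iteration:

  # Hypothetical call site; all argument objects are assumed to exist already.
  ensemble_spec = ensemble_builder.build_ensemble_spec(
      name="complexity_regularized",
      candidate=candidate,                # an adanet.ensemble.Candidate
      ensembler=ensemble_lib.ComplexityRegularizedEnsembler(),
      subnetwork_specs=specs,             # this iteration's _SubnetworkSpecs
      summary=summary,                    # a _ScopedSummary instance
      features=features,
      mode=tf.estimator.ModeKeys.TRAIN,
      iteration_number=0,
      labels=labels,
      my_ensemble_index=0,
      previous_ensemble_spec=None,        # no ensemble exists at iteration 0
      previous_iteration_checkpoint=None)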