def build_ensemble()

in adanet/ensemble/weighted.py [0:0]


  def build_ensemble(self,
                     subnetworks,
                     previous_ensemble_subnetworks,
                     features,
                     labels,
                     logits_dimension,
                     training,
                     iteration_step,
                     summary,
                     previous_ensemble,
                     previous_iteration_checkpoint=None):
    """Builds a weighted ensemble from carried-over and new subnetworks.

    Subnetworks kept from `previous_ensemble` are wrapped first, in the order
    they appear in `previous_ensemble.weighted_subnetworks`, followed by the
    new `subnetworks`. Each one is built inside its own
    `weighted_subnetwork_{index}` variable scope, with indices assigned
    consecutively from 0.

    Args:
      subnetworks: New subnetworks to add to the ensemble.
      previous_ensemble_subnetworks: Subnetworks from the previous ensemble to
        keep. A previous weighted subnetwork whose subnetwork is not in this
        collection is pruned. May be empty or `None`.
      features: Unused.
      labels: Unused.
      logits_dimension: Unused.
      training: Unused.
      iteration_step: Unused.
      summary: Forwarded to `_create_ensemble_logits` and
        `_compute_complexity_regularization`.
      previous_ensemble: The previous iteration's ensemble, or a falsy value
        when there is none.
      previous_iteration_checkpoint: Optional. Passed to `_load_variable` when
        warm starting mixture weights, and to `_create_bias_term` when a
        previous ensemble exists.

    Returns:
      A `ComplexityRegularized` holding the weighted subnetworks, the bias,
      the underlying subnetworks, the ensemble logits, and the complexity
      regularization term.
    """
    del features, labels, logits_dimension, training, iteration_step  # unused

    # Normalize once. The carry-over branch already treats a falsy collection
    # as "keep nothing". Without this, the bias branch below would call
    # `len(None)` and raise whenever a previous ensemble exists.
    if previous_ensemble_subnetworks is None:
      previous_ensemble_subnetworks = []

    weighted_subnetworks = []
    subnetwork_index = 0
    num_subnetworks = len(subnetworks)

    if previous_ensemble_subnetworks and previous_ensemble:
      num_subnetworks += len(previous_ensemble_subnetworks)
      for weighted_subnetwork in previous_ensemble.weighted_subnetworks:
        if weighted_subnetwork.subnetwork not in previous_ensemble_subnetworks:
          # Pruned.
          continue
        weight_initializer = None
        if self._warm_start_mixture_weights:
          # Multi-head subnetworks expose a dict `last_layer`; load one weight
          # per key in sorted order so the mapping is deterministic.
          if isinstance(weighted_subnetwork.subnetwork.last_layer, dict):
            weight_initializer = {
                key: self._load_variable(weighted_subnetwork.weight[key],
                                         previous_iteration_checkpoint)
                for key in sorted(weighted_subnetwork.subnetwork.last_layer)
            }
          else:
            weight_initializer = self._load_variable(
                weighted_subnetwork.weight, previous_iteration_checkpoint)
        with tf_compat.v1.variable_scope(
            "weighted_subnetwork_{}".format(subnetwork_index)):
          weighted_subnetworks.append(
              self._build_weighted_subnetwork(
                  weighted_subnetwork.subnetwork,
                  num_subnetworks,
                  weight_initializer=weight_initializer))
        subnetwork_index += 1

    # New subnetworks continue the scope numbering after the carried-over ones.
    for subnetwork in subnetworks:
      with tf_compat.v1.variable_scope(
          "weighted_subnetwork_{}".format(subnetwork_index)):
        weighted_subnetworks.append(
            self._build_weighted_subnetwork(subnetwork, num_subnetworks))
      subnetwork_index += 1

    if previous_ensemble:
      # The previous bias is used as the prior only when no previous
      # subnetwork was pruned (compared by count).
      keeps_all_previous = len(previous_ensemble.subnetworks) == len(
          previous_ensemble_subnetworks)
      bias = self._create_bias_term(
          weighted_subnetworks,
          prior=previous_ensemble.bias if keeps_all_previous else None,
          previous_iteration_checkpoint=previous_iteration_checkpoint)
      if not keeps_all_previous:
        logging.info("Builders using a pruned set of the subnetworks "
                     "from the previous ensemble, so its ensemble's bias "
                     "term will not be warm started with the previous "
                     "ensemble's bias.")
    else:
      bias = self._create_bias_term(weighted_subnetworks)

    logits = self._create_ensemble_logits(weighted_subnetworks, bias, summary)
    if isinstance(logits, dict):
      # One regularization term per logits key, summed in sorted-key order.
      complexity_regularization = sum(
          self._compute_complexity_regularization(weighted_subnetworks,
                                                  summary, key)
          for key in sorted(logits))
    else:
      complexity_regularization = self._compute_complexity_regularization(
          weighted_subnetworks, summary)

    return ComplexityRegularized(
        weighted_subnetworks=weighted_subnetworks,
        bias=bias,
        subnetworks=[ws.subnetwork for ws in weighted_subnetworks],
        logits=logits,
        complexity_regularization=complexity_regularization)