def build_subnetwork_spec()

in adanet/core/ensemble_builder.py [0:0]


  def build_subnetwork_spec(self,
                            name,
                            subnetwork_builder,
                            summary,
                            features,
                            mode,
                            labels=None,
                            previous_ensemble=None,
                            config=None):
    """Builds a `_SubnetworkSpec` from the given `adanet.subnetwork.Builder`.

    Creates the subnetwork's variables inside a dedicated variable scope,
    invokes the user-supplied builder (passing only the arguments its
    signature declares, for backwards compatibility), and collects every
    variable the builder created so the spec can track them explicitly.

    Args:
      name: String name of the subnetwork.
      subnetwork_builder: A `adanet.Builder` instance which defines how to train
        the subnetwork and ensemble mixture weights.
      summary: A `_ScopedSummary` instance for recording ensemble summaries.
      features: Input `dict` of `Tensor` objects.
      mode: Estimator's `ModeKeys`.
      labels: Labels `Tensor` or a dictionary of string label name to `Tensor`
        (for multi-head). Can be `None`.
      previous_ensemble: The previous `Ensemble` from iteration t-1. Used for
        creating the subnetwork train_op.
      config: The `tf.estimator.RunConfig` to use this iteration.

    Returns:
      A new `_SubnetworkSpec` instance describing the built subnetwork: its
      predictions, loss, train op, eval metrics, step counter, and the
      variables it created.
    """

    # Snapshot the graph's variables so we can later diff out exactly the
    # variables created while building this subnetwork.
    old_vars = _get_current_vars()

    with tf_compat.v1.variable_scope("subnetwork_{}".format(name)):
      # Per-subnetwork step counter; non-trainable so it is excluded from
      # gradient updates but still saved/restored with the model.
      step = tf_compat.v1.get_variable(
          "step",
          shape=[],
          initializer=tf_compat.v1.zeros_initializer(),
          trainable=False,
          dtype=tf.int64)

      # Convert to tensor so that users cannot mutate it.
      step_tensor = tf.convert_to_tensor(value=step)
      with summary.current_scope():
        summary.scalar("iteration_step/adanet/iteration_step", step_tensor)
      if config:
        # Give each subnetwork its own assets subdirectory under the main
        # model_dir so per-subnetwork artifacts don't collide.
        subnetwork_config = config.replace(
            model_dir=os.path.join(config.model_dir, "assets", name))
      else:
        # No config supplied: fall back to a default RunConfig.
        # allow_growth presumably keeps TF from grabbing all GPU memory
        # up front — NOTE(review): confirm this default is intentional.
        subnetwork_config = tf.estimator.RunConfig(
            session_config=tf.compat.v1.ConfigProto(
                gpu_options=tf.compat.v1.GPUOptions(allow_growth=True)))

      # Bind the always-supported arguments now; optional ones are added
      # below only if the user's build_subnetwork signature declares them.
      build_subnetwork = functools.partial(
          subnetwork_builder.build_subnetwork,
          features=features,
          logits_dimension=self._head.logits_dimension,
          training=mode == tf.estimator.ModeKeys.TRAIN,
          iteration_step=step_tensor,
          summary=summary,
          previous_ensemble=previous_ensemble)
      # Check which args are in the implemented build_subnetwork method
      # signature for backwards compatibility.
      # Calling low level getargs for py_2_and_3 compatibility.
      defined_args = inspect.getargs(
          subnetwork_builder.build_subnetwork.__code__).args
      if "labels" in defined_args:
        build_subnetwork = functools.partial(build_subnetwork, labels=labels)
      if "config" in defined_args:
        build_subnetwork = functools.partial(
            build_subnetwork, config=subnetwork_config)
      subnetwork_scope = tf_compat.v1.get_variable_scope()
      # Build inside the monkey-patched context so summaries and trainable
      # variable tracking are scoped to this subnetwork only.
      with summary.current_scope(), _monkey_patch_context(
          iteration_step_scope=subnetwork_scope,
          scoped_summary=summary,
          trainable_vars=[]):
        subnetwork = build_subnetwork()

      # Trainable variables created since the snapshot above, i.e. by this
      # subnetwork's build call alone.
      subnetwork_var_list = _get_current_vars(diffbase=old_vars)["trainable"]

      estimator_spec = _create_estimator_spec(self._head, features, labels,
                                              mode, subnetwork.logits,
                                              self._use_tpu)

      # Eval metrics are only materialized in EVAL mode.
      subnetwork_metrics = _SubnetworkMetrics(self._use_tpu)
      if mode == tf.estimator.ModeKeys.EVAL:
        subnetwork_metrics.create_eval_metrics(
            features=features,
            labels=labels,
            estimator_spec=estimator_spec,
            metric_fn=self._metric_fn)

      if mode == tf.estimator.ModeKeys.TRAIN:
        with summary.current_scope():
          summary.scalar("loss", estimator_spec.loss)

      # Create train ops for training subnetworks and ensembles.
      train_op = None
      if mode == tf.estimator.ModeKeys.TRAIN and subnetwork_builder:
        # Re-enter the patched context so any variables/summaries created by
        # the user's train-op builder are also attributed to this subnetwork.
        with summary.current_scope(), _monkey_patch_context(
            iteration_step_scope=subnetwork_scope,
            scoped_summary=summary,
            trainable_vars=subnetwork_var_list):
          train_op = _to_train_op_spec(
              subnetwork_builder.build_subnetwork_train_op(
                  subnetwork=subnetwork,
                  loss=estimator_spec.loss,
                  var_list=subnetwork_var_list,
                  labels=labels,
                  iteration_step=step_tensor,
                  summary=summary,
                  previous_ensemble=previous_ensemble))

      # Re-diff to also capture variables created by the train-op build.
      new_vars = _get_current_vars(diffbase=old_vars)
      # Sort our dictionary by key to remove non-determinism of variable order.
      new_vars = collections.OrderedDict(sorted(new_vars.items()))
      # Combine all trainable, global and savable variables into a single list.
      # sum(..., []) flattens the per-category lists; `step` is appended since
      # it was created before the diff categories were recorded in new_vars.
      subnetwork_variables = sum(new_vars.values(), []) + [step]

    return _SubnetworkSpec(
        name=name,
        subnetwork=subnetwork,
        builder=subnetwork_builder,
        predictions=estimator_spec.predictions,
        variables=subnetwork_variables,
        loss=estimator_spec.loss,
        step=step,
        train_op=train_op,
        eval_metrics=subnetwork_metrics,
        asset_dir=subnetwork_config.model_dir)