def _build_weighted_subnetwork_helper()

in adanet/ensemble/weighted.py


  def _build_weighted_subnetwork_helper(self,
                                        subnetwork,
                                        num_subnetworks,
                                        weight_initializer=None,
                                        key=None,
                                        index=None):
    """Returns the logits and weight of the `WeightedSubnetwork` for key."""

    # Treat subnetworks as if their weights are frozen, and ensure that
    # mixture weight gradients do not propagate through.
    last_layer = _lookup_if_dict(subnetwork.last_layer, key)
    logits = _lookup_if_dict(subnetwork.logits, key)
    weight_shape = None
    last_layer_size = last_layer.get_shape().as_list()[-1]
    logits_size = logits.get_shape().as_list()[-1]
    batch_size = tf.shape(input=last_layer)[0]

    if weight_initializer is None:
      weight_initializer = self._select_mixture_weight_initializer(
          num_subnetworks)
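      # The weight's shape determines how the subnetwork is combined:
      #   SCALAR: [] -- one scalar scaling the subnetwork's logits.
      #   VECTOR: [logits_size] -- one weight per logit dimension.
      #   MATRIX: [last_layer_size, logits_size] -- a learned linear map from
      #     the subnetwork's last layer to the ensemble logits.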
      if self._mixture_weight_type == MixtureWeightType.SCALAR:
        weight_shape = []
      if self._mixture_weight_type == MixtureWeightType.VECTOR:
        weight_shape = [logits_size]
      if self._mixture_weight_type == MixtureWeightType.MATRIX:
        weight_shape = [last_layer_size, logits_size]

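    # Note: an index of 0 is falsy, so the first keyed output falls back to
    # the bare "logits" scope used by the single-output case.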
    with tf_compat.v1.variable_scope(
        "logits_{}".format(index) if index else "logits"):
      weight = tf_compat.v1.get_variable(
          name="mixture_weight",
          shape=weight_shape,
          initializer=weight_initializer)
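      # Note: weight_shape is None when an explicit weight_initializer was
      # supplied (e.g. mixture weights carried over from a previous ensemble);
      # get_variable then infers the variable's shape from the initializer.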
      if self._mixture_weight_type == MixtureWeightType.MATRIX:
        # TODO: Add unit tests for the ndims == 3 path.
        ndims = len(last_layer.get_shape().as_list())
        if ndims > 3:
          raise NotImplementedError(
              "Last layers with more than 3 dimensions are not supported "
              "with matrix mixture weights.")
        # Reshape [batch_size, timesteps, emb_dim] to
        # [batch_size x timesteps, emb_dim] for the matrix multiplication,
        # then reshape back afterwards.
        if ndims == 3:
          logging.info("Rank 3 tensors like [batch_size, timesteps, d] are "
                       "reshaped to rank 2 [batch_size x timesteps, d] for "
                       "the weight matrix multiplication, and reshaped back "
                       "to their original shape afterwards.")
          last_layer = tf.reshape(last_layer, [-1, last_layer_size])
        logits = tf.matmul(last_layer, weight)
        if ndims == 3:
          logits = tf.reshape(logits, [batch_size, -1, logits_size])
      else:
        logits = tf.multiply(logits, weight)
    return logits, weight
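
For reference, the rank-3 branch above is the usual flatten-matmul-restore
pattern. Below is a minimal standalone sketch of the same round trip in plain
TensorFlow 2; the sizes B, T, D, L and the tensor names are illustrative
assumptions, not part of adanet:

import tensorflow as tf

B, T, D, L = 4, 7, 16, 3  # batch, timesteps, last-layer size, logits size

last_layer = tf.random.normal([B, T, D])  # rank-3 subnetwork output
weight = tf.random.normal([D, L])         # MATRIX mixture weight

# Flatten [B, T, D] -> [B*T, D] so one 2-D matmul applies the weight to
# every timestep, then restore the leading batch/timestep dimensions.
flat = tf.reshape(last_layer, [-1, D])        # [B*T, D]
flat_logits = tf.matmul(flat, weight)         # [B*T, L]
logits = tf.reshape(flat_logits, [B, -1, L])  # [B, T, L]

assert logits.shape == (B, T, L)
# The same contraction in one step:
#   tf.einsum("btd,dl->btl", last_layer, weight)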