in adanet/ensemble/weighted.py [0:0]
def _build_weighted_subnetwork_helper(self,
                                      subnetwork,
                                      num_subnetworks,
                                      weight_initializer=None,
                                      key=None,
                                      index=None):
  """Creates one mixture weight and applies it to a subnetwork's output.

  Args:
    subnetwork: Object exposing `last_layer` and `logits`. Each is resolved
      through `_lookup_if_dict(..., key)` (presumably selecting `value[key]`
      when the value is a dict -- confirm against that helper).
    num_subnetworks: Forwarded to `self._select_mixture_weight_initializer`
      when no `weight_initializer` is supplied.
    weight_initializer: Initializer for the `mixture_weight` variable. When
      None, one is chosen by `self._select_mixture_weight_initializer` and the
      variable shape follows `self._mixture_weight_type`:
      SCALAR -> [], VECTOR -> [logits_size],
      MATRIX -> [last_layer_size, logits_size].
      When supplied, the variable is requested with `shape=None`, so the
      initializer itself must determine the shape (NOTE(review): presumably a
      tensor/constant such as a previously learned weight -- confirm).
    key: Optional key used to resolve `last_layer` / `logits`.
    index: Optional variable-scope suffix. Any truthy value yields the scope
      "logits_<index>"; both None and 0 yield "logits". (Changing this would
      rename the variable, which presumably breaks existing checkpoints.)

  Returns:
    A `(logits, weight)` tuple. For MATRIX mixture weights the logits are
    `tf.matmul(last_layer, weight)`. Rank-3 inputs are flattened to rank 2 for
    the matmul and reshaped back afterwards. For every other type they are
    `tf.multiply(logits, weight)`.

  Raises:
    NotImplementedError: MATRIX mixture weights with a `last_layer` whose
      static rank exceeds 3.
  """
  # NOTE(review): no `tf.stop_gradient` is applied to the subnetwork outputs
  # in this method. If they must be treated as frozen, that presumably
  # happens elsewhere -- confirm.
  features = _lookup_if_dict(subnetwork.last_layer, key)
  subnetwork_logits = _lookup_if_dict(subnetwork.logits, key)
  # Static (graph-construction-time) sizes of the trailing dimensions.
  last_layer_size = features.get_shape().as_list()[-1]
  logits_size = subnetwork_logits.get_shape().as_list()[-1]
  # Runtime batch dimension; used to restore rank-3 outputs in the MATRIX path.
  batch_size = tf.shape(input=features)[0]

  # The shape is only chosen alongside the default initializer. A caller-
  # provided initializer is expected to carry its own shape.
  weight_shape = None
  if weight_initializer is None:
    weight_initializer = self._select_mixture_weight_initializer(
        num_subnetworks)
    if self._mixture_weight_type == MixtureWeightType.SCALAR:
      weight_shape = []
    elif self._mixture_weight_type == MixtureWeightType.VECTOR:
      weight_shape = [logits_size]
    elif self._mixture_weight_type == MixtureWeightType.MATRIX:
      weight_shape = [last_layer_size, logits_size]

  if index:
    scope_name = "logits_{}".format(index)
  else:
    scope_name = "logits"
  with tf_compat.v1.variable_scope(scope_name):
    weight = tf_compat.v1.get_variable(
        name="mixture_weight",
        shape=weight_shape,
        initializer=weight_initializer)

  if self._mixture_weight_type != MixtureWeightType.MATRIX:
    # Non-MATRIX weights scale the logits elementwise (with broadcasting).
    return tf.multiply(subnetwork_logits, weight), weight

  # TODO: Add Unit tests for the ndims == 3 path.
  rank = len(features.get_shape().as_list())
  if rank > 3:
    raise NotImplementedError(
        "Last Layer with more than 3 dimensions are not supported with "
        "matrix mixture weights.")
  is_sequence = rank == 3
  if is_sequence:
    # Collapse the first two axes so a single 2-D matmul applies the weight
    # matrix at every position, then restore them after the product.
    logging.info("Rank 3 tensors like [batch_size, timesteps, d] are "
                 "reshaped to rank 2 [ batch_size x timesteps, d] for "
                 "the weight matrix multiplication, and are reshaped "
                 "to their original shape afterwards.")
    features = tf.reshape(features, [-1, last_layer_size])
  mixed_logits = tf.matmul(features, weight)
  if is_sequence:
    mixed_logits = tf.reshape(mixed_logits, [batch_size, -1, logits_size])
  return mixed_logits, weight