in adanet/ensemble/weighted.py [0:0]
def build_ensemble(self,
                   subnetworks,
                   previous_ensemble_subnetworks,
                   features,
                   labels,
                   logits_dimension,
                   training,
                   iteration_step,
                   summary,
                   previous_ensemble,
                   previous_iteration_checkpoint=None):
  """Builds a complexity-regularized weighted ensemble of subnetworks.

  Creates one mixture weight per subnetwork — first for the subnetworks kept
  from `previous_ensemble`, then for the newly proposed `subnetworks` — each
  under its own sequentially numbered "weighted_subnetwork_{index}" variable
  scope. A bias term is then created, ensemble logits are computed, and a
  complexity regularization penalty is accumulated over the logits (summed
  per head when `logits` is a dict).

  Args:
    subnetworks: New candidate subnetworks to add to the ensemble this
      iteration.
    previous_ensemble_subnetworks: The subnetworks from `previous_ensemble`
      to keep. Any previous weighted subnetwork whose subnetwork is not in
      this collection is pruned (skipped). NOTE(review): when
      `previous_ensemble` is set, the bias-term branch below calls
      `len(previous_ensemble_subnetworks)`, so callers appear to pass a
      (possibly empty) sequence rather than None — confirm.
    features: Unused.
    labels: Unused.
    logits_dimension: Unused.
    training: Unused.
    iteration_step: Unused.
    summary: Summary object forwarded to logits creation and to the
      complexity regularization computation.
    previous_ensemble: The ensemble built at the previous iteration, or a
      falsy value when there is none.
    previous_iteration_checkpoint: Checkpoint from which previous mixture
      weights (and the bias prior) are loaded when warm starting.

  Returns:
    A `ComplexityRegularized` ensemble bundling the weighted subnetworks,
    bias, logits, and total complexity regularization.
  """
  del features, labels, logits_dimension, training, iteration_step  # unused
  weighted_subnetworks = []
  # Index drives the "weighted_subnetwork_{}" scope names below; previous
  # (kept) subnetworks are numbered before new ones, preserving ordering.
  subnetwork_index = 0
  num_subnetworks = len(subnetworks)

  if previous_ensemble_subnetworks and previous_ensemble:
    num_subnetworks += len(previous_ensemble_subnetworks)
    for weighted_subnetwork in previous_ensemble.weighted_subnetworks:
      if weighted_subnetwork.subnetwork not in previous_ensemble_subnetworks:
        # Pruned: this previous subnetwork was not selected to remain in
        # the ensemble, so no mixture weight is created for it.
        continue
      weight_initializer = None
      if self._warm_start_mixture_weights:
        # Reload the previous iteration's mixture-weight values from the
        # checkpoint so optimization resumes from them. A dict last_layer
        # (multi-head) gets one loaded weight per head key; iterating
        # sorted keys keeps the construction order deterministic.
        if isinstance(weighted_subnetwork.subnetwork.last_layer, dict):
          weight_initializer = {
              key: self._load_variable(weighted_subnetwork.weight[key],
                                       previous_iteration_checkpoint)
              for key in sorted(weighted_subnetwork.subnetwork.last_layer)
          }
        else:
          weight_initializer = self._load_variable(
              weighted_subnetwork.weight, previous_iteration_checkpoint)
      with tf_compat.v1.variable_scope(
          "weighted_subnetwork_{}".format(subnetwork_index)):
        weighted_subnetworks.append(
            self._build_weighted_subnetwork(
                weighted_subnetwork.subnetwork,
                num_subnetworks,
                weight_initializer=weight_initializer))
      subnetwork_index += 1

  # New subnetworks always get freshly initialized mixture weights (no
  # weight_initializer), continuing the same scope numbering.
  for subnetwork in subnetworks:
    with tf_compat.v1.variable_scope(
        "weighted_subnetwork_{}".format(subnetwork_index)):
      weighted_subnetworks.append(
          self._build_weighted_subnetwork(subnetwork, num_subnetworks))
    subnetwork_index += 1

  if previous_ensemble:
    if len(
        previous_ensemble.subnetworks) == len(previous_ensemble_subnetworks):
      # Nothing was pruned, so the previous ensemble's bias is a valid
      # prior to warm start from.
      bias = self._create_bias_term(
          weighted_subnetworks,
          prior=previous_ensemble.bias,
          previous_iteration_checkpoint=previous_iteration_checkpoint)
    else:
      # Pruning changed the subnetwork set, so the old bias no longer
      # corresponds to this ensemble; start the bias from scratch.
      bias = self._create_bias_term(
          weighted_subnetworks,
          prior=None,
          previous_iteration_checkpoint=previous_iteration_checkpoint)
      logging.info("Builders using a pruned set of the subnetworks "
                   "from the previous ensemble, so its ensemble's bias "
                   "term will not be warm started with the previous "
                   "ensemble's bias.")
  else:
    bias = self._create_bias_term(weighted_subnetworks)

  logits = self._create_ensemble_logits(weighted_subnetworks, bias, summary)
  complexity_regularization = 0
  if isinstance(logits, dict):
    # Multi-head logits: sum the regularization penalty across heads,
    # iterating sorted keys for deterministic graph construction.
    for key in sorted(logits):
      complexity_regularization += self._compute_complexity_regularization(
          weighted_subnetworks, summary, key)
  else:
    complexity_regularization = self._compute_complexity_regularization(
        weighted_subnetworks, summary)

  return ComplexityRegularized(
      weighted_subnetworks=weighted_subnetworks,
      bias=bias,
      subnetworks=[ws.subnetwork for ws in weighted_subnetworks],
      logits=logits,
      complexity_regularization=complexity_regularization)