in research/gam/gam/trainer/trainer_classification_gcn.py [0:0]
def _get_agreement_reg_loss(self, data, is_train):
  """Computes the regularization loss coming from the agreement term.

  This is calculated using the following idea: we incur a loss for pairs of
  samples that should have the same label, but for which the predictions of
  the classification model are not equal. The loss incurred by each pair is
  proportional to the squared Euclidean distance between the two label
  distributions, scaled by the confidence we have that they should agree and
  by a per-category regularization weight (`self.reg_weight_ll`,
  `self.reg_weight_lu` or `self.reg_weight_uu`).

  In the case of pairs where both samples are labeled (LL), the agreement
  confidence is taken directly from the labels: 1.0 if the two labels are
  equal and 0.0 otherwise. When at least one sample is unlabeled (LU, UU), we
  use the agreement model to estimate this confidence.

  Note that for the pairs where a label is available, we compute this loss
  wrt. the actual label instead of the classifier predictions. When both
  samples are labeled (LL), we use the true label for the left sample and the
  prediction for the right one -- otherwise there would be no gradients to
  propagate.

  If `self.penalize_neg_agr` is set, agreement scores are mapped through
  relu(score - 0.5). Pairs scored at or below 0.5 then incur no loss.

  The placeholders created here are stored as attributes on `self` (e.g.
  `self.indices_lu_left`, `self.features_uu_right`) so the caller can feed
  them. Note that `self.labels_ll_left`, `self.labels_ll_right` and
  `self.labels_lu_left` hold the integer label-index placeholders, not the
  one-hot tensors.

  Args:
    data: A CotrainDataset object. Only `features_shape` and `num_classes`
      are read here.
    is_train: A placeholder for a boolean that specifies if this function is
      called as part of model training or inference.

  Returns:
    The computed agreement loss op. This is the mean weighted per-pair loss
    over all LL, LU and UU pairs fed, or 0.0 if no pairs are fed.
  """
  # Select num_pairs_reg pairs of samples from each category LL, LU, UU,
  # for which to do the regularization. These placeholders are fed by the
  # caller; they are exposed as attributes at the end of this method.
  indices_ll_right = tf.placeholder(dtype=tf.int64, shape=(None,))
  indices_lu_left = tf.placeholder(dtype=tf.int64, shape=(None,))
  indices_lu_right = tf.placeholder(dtype=tf.int64, shape=(None,))
  indices_uu_left = tf.placeholder(dtype=tf.int64, shape=(None,))
  indices_uu_right = tf.placeholder(dtype=tf.int64, shape=(None,))
  # First obtain the features shape from the dataset, and append a batch_size
  # dimension to it (i.e., `None` to allow for variable batch size).
  features_shape = [None] + list(data.features_shape)
  # NOTE(review): features_ll_right is created and exposed on `self`, but it
  # is not consumed by the loss computed below.
  features_ll_right = tf.placeholder(dtype=tf.float32, shape=features_shape)
  features_lu_left = tf.placeholder(dtype=tf.float32, shape=features_shape)
  features_lu_right = tf.placeholder(dtype=tf.float32, shape=features_shape)
  features_uu_left = tf.placeholder(dtype=tf.float32, shape=features_shape)
  features_uu_right = tf.placeholder(dtype=tf.float32, shape=features_shape)
  # Integer class indices for the labeled samples; converted to one-hot so
  # they can be compared against predicted label distributions.
  labels_ll_left_idx = tf.placeholder(dtype=tf.int64, shape=(None,))
  labels_ll_right_idx = tf.placeholder(dtype=tf.int64, shape=(None,))
  labels_lu_left_idx = tf.placeholder(dtype=tf.int64, shape=(None,))
  labels_ll_left = tf.one_hot(labels_ll_left_idx, data.num_classes)
  labels_lu_left = tf.one_hot(labels_lu_left_idx, data.num_classes)
  # reuse=True: share the classifier variables already created under the
  # 'predictions' scope rather than creating new ones.
  with tf.variable_scope('predictions', reuse=True):
    # Obtain predictions for all nodes in the graph.
    encoding_all, _, _ = self.model.get_encoding_and_params(
        inputs=self.features_op,
        is_train=is_train,
        support=self.support_op,
        num_features_nonzero=self.num_features_nonzero_op,
        update_batch_stats=False)
    predictions_all, _, _ = self.model.get_predictions_and_params(
        encoding=encoding_all,
        is_train=is_train,
        support=self.support_op,
        num_features_nonzero=self.num_features_nonzero_op)
    predictions_all = self.model.normalize_predictions(predictions_all)
  # Select the nodes of interest.
  predictions_ll_right = tf.gather(predictions_all, indices_ll_right)
  predictions_lu_right = tf.gather(predictions_all, indices_lu_right)
  predictions_uu_left = tf.gather(predictions_all, indices_uu_left)
  predictions_uu_right = tf.gather(predictions_all, indices_uu_right)
  # Compute the squared Euclidean distance between the label distributions
  # that the classification model predicts for the src and tgt of each pair.
  # The rows are laid out as [LL pairs, LU pairs, UU pairs].
  # TODO: Stop gradients need to be added.
  # TODO: The case where there are no more LU or UU edges at the end of
  # training, so the shapes don't match, needs fixing.
  left = tf.concat((labels_ll_left, labels_lu_left, predictions_uu_left),
                   axis=0)
  right = tf.concat(
      (predictions_ll_right, predictions_lu_right, predictions_uu_right),
      axis=0)
  dists = tf.reduce_sum(tf.square(left - right), axis=-1)
  # Estimate a weight for each distance, depending on the predictions
  # of the agreement model. For the labeled samples, we can use the actual
  # agreement between the labels, no need to estimate.
  agreement_ll = tf.cast(
      tf.equal(labels_ll_left_idx, labels_ll_right_idx), dtype=tf.float32)
  _, agreement_lu, _, _ = self.trainer_agr.create_agreement_prediction(
      src_features=features_lu_left,
      tgt_features=features_lu_right,
      is_train=is_train,
      src_indices=indices_lu_left,
      tgt_indices=indices_lu_right)
  _, agreement_uu, _, _ = self.trainer_agr.create_agreement_prediction(
      src_features=features_uu_left,
      tgt_features=features_uu_right,
      is_train=is_train,
      src_indices=indices_uu_left,
      tgt_indices=indices_uu_right)
  # Same [LL, LU, UU] row layout as `dists`.
  agreement = tf.concat((agreement_ll, agreement_lu, agreement_uu), axis=0)
  if self.penalize_neg_agr:
    # Since the agreement is predicting scores between [0, 1], anything
    # under 0.5 should represent disagreement. Therefore, we want to encourage
    # agreement whenever the score is > 0.5, otherwise don't incur any loss.
    agreement = tf.nn.relu(agreement - 0.5)
  # Create a Tensor containing the weights assigned to each pair in the
  # agreement regularization loss, depending on how many samples in the pair
  # were labeled. This weight can be either reg_weight_ll, reg_weight_lu,
  # or reg_weight_uu.
  num_ll = tf.shape(predictions_ll_right)[0]
  num_lu = tf.shape(predictions_lu_right)[0]
  num_uu = tf.shape(predictions_uu_left)[0]
  weights = tf.concat(
      (self.reg_weight_ll * tf.ones(num_ll,),
       self.reg_weight_lu * tf.ones(num_lu,),
       self.reg_weight_uu * tf.ones(num_uu,)),
      axis=0)
  # Scale each distance by its agreement weight and regularization weight.
  weighted_dists = dists * weights * agreement
  # Average over all pairs. tf.reduce_mean of an empty tensor is 0/0 = NaN,
  # which would poison the total loss and its gradients whenever no pairs are
  # fed. Clamping the denominator to at least 1 yields 0.0 in that case and
  # is identical to the mean for any non-empty batch.
  num_pairs = tf.cast(tf.size(weighted_dists), dtype=tf.float32)
  loss = tf.reduce_sum(weighted_dists) / tf.maximum(num_pairs, 1.0)
  # Expose the placeholders (and the LU agreement op) so callers can feed and
  # inspect them.
  self.indices_ll_right = indices_ll_right
  self.indices_lu_left = indices_lu_left
  self.indices_lu_right = indices_lu_right
  self.indices_uu_left = indices_uu_left
  self.indices_uu_right = indices_uu_right
  self.features_ll_right = features_ll_right
  self.features_lu_left = features_lu_left
  self.features_lu_right = features_lu_right
  self.features_uu_left = features_uu_left
  self.features_uu_right = features_uu_right
  self.labels_ll_left = labels_ll_left_idx
  self.labels_ll_right = labels_ll_right_idx
  self.labels_lu_left = labels_lu_left_idx
  self.agreement_lu = agreement_lu
  return loss