in model/model.py [0:0]
def loss_and_train_ops_for_attr_vector(self, x, y_):
"""
Loss and training operations for training the joint embedding model
"""
# Define the appropriate network
if self.network_arch == 'FC-S':
input_dim = int(x.get_shape()[1])
layer_dims = [input_dim, 256, 256, self.total_classes]
self.fc_variables(layer_dims)
logits = self.fc_feedforward(x, self.weights, self.biases)
elif self.network_arch == 'FC-B':
input_dim = int(x.get_shape()[1])
layer_dims = [input_dim, 2000, 2000, self.total_classes]
self.fc_variables(layer_dims)
logits = self.fc_feedforward(x, self.weights, self.biases)
elif self.network_arch == 'CNN':
num_channels = int(x.get_shape()[-1])
self.image_size = int(x.get_shape()[1])
kernels = [3, 3, 3, 3, 3]
depth = [num_channels, 32, 32, 64, 64, 512]
self.conv_variables(kernels, depth)
logits = self.conv_feedforward(x, self.weights, self.biases, apply_dropout=True)
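# NOTE: the FC-S, FC-B and CNN branches compute class logits directly,
# whereas the attribute-embedding path below relies on image features
# phi_x, which only the VGG and RESNET branches produce; calling this
# method with an FC/CNN arch would leave phi_x undefined.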
elif self.network_arch == 'VGG':
# VGG-16
phi_x = self.vgg_16_conv_feedforward(x)
elif self.network_arch == 'RESNET-S':
# Reduced ResNet-18 (thinner filters than the standard variant)
kernels = [3, 3, 3, 3, 3]
filters = [20, 20, 40, 80, 160]
strides = [1, 0, 2, 2, 2]
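# The stride of 0 for the second group is presumably a placeholder handled
# inside resnet18_conv_feedforward (the first residual stage typically
# keeps the spatial resolution); this is an assumption about that helper.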
# Get the image features
phi_x = self.resnet18_conv_feedforward(x, kernels, filters, strides)
elif self.network_arch == 'RESNET-B':
# Standard ResNet-18
kernels = [7, 3, 3, 3, 3]
filters = [64, 64, 128, 256, 512]
strides = [2, 0, 2, 2, 2]
# Get the image features
phi_x = self.resnet18_conv_feedforward(x, kernels, filters, strides)
# Get the attribute embeddings; no biases yet, shape: TOTAL_CLASSES x image_feature_dim
attr_embed = self.get_attribute_embedding(self.class_attr)
# Add the biases now
last_layer_biases = bias_variable([self.total_classes], name='attr_embed_b')
self.trainable_vars.append(last_layer_biases)
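# last_layer_biases has shape [TOTAL_CLASSES], one bias per class; appending
# it to trainable_vars lets the importance-based methods regularize it too.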
# Now that we have all the trainable variables, initialize the bookkeeping variables
# Note: This method has to be called before calculating Fisher
# or any other importance measure
self.init_vars()
# Compute the logits for the ZST (zero-shot transfer) case
zst_logits = tf.matmul(phi_x, tf.transpose(attr_embed)) + last_layer_biases
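# Shape check: phi_x is [batch_size, feature_dim] and attr_embed is
# [TOTAL_CLASSES, feature_dim], so zst_logits is [batch_size, TOTAL_CLASSES]:
# every image feature is scored against every class's attribute embedding.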
# Prune the predictions to include only the classes for which
# training data is present
if self.imp_method == 'A-GEM':
pruned_zst_logits = []
self.unweighted_entropy = []
for i in range(self.num_tasks):
pruned_zst_logits.append(tf.where(tf.tile(tf.equal(self.output_mask[i][None,:], 1.0),
[tf.shape(zst_logits)[0], 1]), zst_logits, NEG_INF*tf.ones_like(zst_logits)))
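# The per-task output mask is tiled across the batch; logits of classes
# outside task i are replaced by NEG_INF so softmax assigns them ~zero
# probability. E.g., with output_mask[i] = [1, 1, 0, 0] and batch size 2,
# the condition is [[T, T, F, F], [T, T, F, F]] and columns 2-3 drop out
# of the cross-entropy.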
cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=pruned_zst_logits[i])
adjusted_entropy = tf.reduce_sum(tf.cast(tf.tile(tf.equal(self.output_mask[i][None,:], 1.0), [tf.shape(y_)[0], 1]), dtype=tf.float32) * y_, axis=1) * cross_entropy
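# The label-side mask keeps an example's loss only when its one-hot label
# lies inside task i's output mask, so each entry of unweighted_entropy
# accumulates the loss of exactly that task's examples.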
self.unweighted_entropy.append(tf.reduce_sum(adjusted_entropy))
else:
pruned_zst_logits = tf.where(tf.tile(tf.equal(self.output_mask[None,:], 1.0),
[tf.shape(zst_logits)[0], 1]), zst_logits, NEG_INF*tf.ones_like(zst_logits))
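# Here self.output_mask is a single mask (presumably updated as new tasks
# arrive) rather than one mask per task as in the A-GEM branch above.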
self.unweighted_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=pruned_zst_logits))
self.mse = 2.0*tf.nn.l2_loss(pruned_zst_logits) # tf.nn.l2_loss(t) computes sum(t ** 2) / 2, so this is the plain sum of squares
# Create operations for loss and gradient calculation
self.loss_and_gradients(self.imp_method)
# Store the current weights before doing a train step
self.get_current_weights()
if 'GEM' not in self.imp_method:
self.train_op()
# Create operations to compute importance depending on the importance methods
if self.imp_method == 'EWC':
self.create_fisher_ops()
elif self.imp_method == 'M-EWC':
self.create_fisher_ops()
self.create_pathint_ops()
self.combined_fisher_pathint_ops()
elif self.imp_method == 'PI':
self.create_pathint_ops()
elif self.imp_method == 'RWALK':
self.create_fisher_ops()
self.create_pathint_ops()
elif self.imp_method == 'MAS':
self.create_hebbian_ops()
elif (self.imp_method == 'A-GEM') or (self.imp_method == 'S-GEM'):
self.create_stochastic_gem_ops()
# Create weight save and store ops
self.weights_store_ops()
# Summary operations for visualization
# (the logged loss is the softmax cross-entropy, not a triplet loss)
# NOTE: under A-GEM, self.unweighted_entropy is a Python list of per-task
# losses, which tf.summary.scalar will reject; per-task scalar summaries
# would be needed in that case.
tf.summary.scalar("cross_entropy_loss", self.unweighted_entropy)
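# TensorFlow variable names end in ':0', and ':' is not allowed in summary
# tags, hence the replacement below.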
for v in self.trainable_vars:
tf.summary.histogram(v.name.replace(":", "_"), v)
self.merged_summary = tf.summary.merge_all()
# Accuracy measure
if self.imp_method == 'A-GEM' and 'FC-' not in self.network_arch:
self.correct_predictions = []
self.accuracy = []
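# Per-task accuracy: predictions for task i are taken over its pruned
# logits, so argmax can never select a class outside that task.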
for i in range(self.num_tasks):
self.correct_predictions.append(tf.equal(tf.argmax(pruned_zst_logits[i], 1), tf.argmax(y_, 1)))
self.accuracy.append(tf.reduce_mean(tf.cast(self.correct_predictions[i], tf.float32)))
else:
self.correct_predictions = tf.equal(tf.argmax(pruned_zst_logits, 1), tf.argmax(y_, 1))
self.accuracy = tf.reduce_mean(tf.cast(self.correct_predictions, tf.float32))