in model/model.py [0:0]
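# Context (assumptions, not shown in this snippet): `tf` is TensorFlow 1.x in
# graph mode (the code uses tf.nn.softmax_cross_entropy_with_logits_v2,
# tf.losses, and tf.summary), and NEG_INF is a module-level constant holding a
# very large negative value (e.g. -1e32) used to mask out logits of absent classes.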
def loss_and_train_ops_for_one_hot_vector(self, x, y_):
    """
    Loss and training operations for training a one-hot-vector-based
    classification model.
    """
    # Define the appropriate network
    if self.network_arch == 'FC-S':
        input_dim = int(x.get_shape()[1])
        layer_dims = [input_dim, 256, 256, self.total_classes]
        if self.imp_method == 'PNN':
            # Progressive nets: one column per task; columns for later tasks
            # receive lateral connections from the earlier columns.
            self.task_logits = []
            self.task_pruned_logits = []
            self.unweighted_entropy = []
            for i in range(self.num_tasks):
                if i == 0:
                    self.task_logits.append(self.init_fc_column_progNN(layer_dims, x))
                else:
                    self.task_logits.append(self.extensible_fc_column_progNN(layer_dims, x, i))
                # Mask out the logits of classes that do not belong to task i
                self.task_pruned_logits.append(tf.where(tf.tile(tf.equal(self.output_mask[i][None,:], 1.0), [tf.shape(self.task_logits[i])[0], 1]), self.task_logits[i], NEG_INF*tf.ones_like(self.task_logits[i])))
                self.unweighted_entropy.append(tf.squeeze(tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_[i], logits=self.task_pruned_logits[i]))))
        else:
            self.fc_variables(layer_dims)
            logits = self.fc_feedforward(x, self.weights, self.biases)
    elif self.network_arch == 'FC-B':
        input_dim = int(x.get_shape()[1])
        layer_dims = [input_dim, 2000, 2000, self.total_classes]
        self.fc_variables(layer_dims)
        logits = self.fc_feedforward(x, self.weights, self.biases)
    elif self.network_arch == 'CNN':
        num_channels = int(x.get_shape()[-1])
        self.image_size = int(x.get_shape()[1])
        kernels = [3, 3, 3, 3, 3]
        depth = [num_channels, 32, 32, 64, 64, 512]
        self.conv_variables(kernels, depth)
        logits = self.conv_feedforward(x, self.weights, self.biases, apply_dropout=True)
    elif self.network_arch == 'VGG':
        # VGG-16
        logits = self.vgg_16_conv_feedforward(x)
    elif 'RESNET-' in self.network_arch:
        if self.network_arch == 'RESNET-S':
            # Reduced ResNet-18, same as used in the GEM paper
            kernels = [3, 3, 3, 3, 3]
            filters = [20, 20, 40, 80, 160]
            strides = [1, 0, 2, 2, 2]
        elif self.network_arch == 'RESNET-B':
            # Standard ResNet-18
            kernels = [7, 3, 3, 3, 3]
            filters = [64, 64, 128, 256, 512]
            strides = [2, 0, 2, 2, 2]
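        # Note: strides[1] is presumably a placeholder; in a standard ResNet-18
        # the conv2 stage uses identity-stride residual blocks, so that entry
        # would go unused by the feedforward builder.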
        if self.imp_method == 'PNN':
            self.task_logits = []
            self.task_pruned_logits = []
            self.unweighted_entropy = []
            for i in range(self.num_tasks):
                if i == 0:
                    self.task_logits.append(self.init_resent_column_progNN(x, kernels, filters, strides))
                else:
                    self.task_logits.append(self.extensible_resnet_column_progNN(x, kernels, filters, strides, i))
                self.task_pruned_logits.append(tf.where(tf.tile(tf.equal(self.output_mask[i][None,:], 1.0), [tf.shape(self.task_logits[i])[0], 1]), self.task_logits[i], NEG_INF*tf.ones_like(self.task_logits[i])))
                self.unweighted_entropy.append(tf.squeeze(tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_[i], logits=self.task_pruned_logits[i]))))
        elif self.imp_method == 'A-GEM' or self.imp_method == 'ER':
            logits = self.resnet18_conv_feedforward(x, kernels, filters, strides)
            self.task_pruned_logits = []
            self.unweighted_entropy = []
            for i in range(self.num_tasks):
                self.task_pruned_logits.append(tf.where(tf.tile(tf.equal(self.output_mask[i][None,:], 1.0), [tf.shape(logits)[0], 1]), logits, NEG_INF*tf.ones_like(logits)))
                cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=self.task_pruned_logits[i])
                # Keep the loss only for examples whose true class belongs to task i
                adjusted_entropy = tf.reduce_sum(tf.cast(tf.tile(tf.equal(self.output_mask[i][None,:], 1.0), [tf.shape(y_)[0], 1]), dtype=tf.float32) * y_, axis=1) * cross_entropy
                self.unweighted_entropy.append(tf.reduce_sum(adjusted_entropy)) # Summed here; averaged later on
        else:
            logits = self.resnet18_conv_feedforward(x, kernels, filters, strides)
    # Prune the predictions to include only the classes for which training
    # data is present. The guard skips the settings whose per-task pruned
    # logits were already defined above (PNN, and A-GEM/ER on non-FC architectures).
    if (self.imp_method != 'PNN') and ((self.imp_method != 'A-GEM' and self.imp_method != 'ER') or 'FC-' in self.network_arch):
        self.pruned_logits = tf.where(tf.tile(tf.equal(self.output_mask[None,:], 1.0), [tf.shape(logits)[0], 1]), logits, NEG_INF*tf.ones_like(logits))
    # Create a list of variables for storing the different importance measures.
    # Note: this method has to be called before computing the Fisher information
    # or any other importance measure.
    self.init_vars()
    # Different entropy measures / loss definitions
    if (self.imp_method != 'PNN') and ((self.imp_method != 'A-GEM' and self.imp_method != 'ER') or 'FC-' in self.network_arch):
        self.mse = 2.0*tf.nn.l2_loss(self.pruned_logits) # tf.nn.l2_loss computes sum(t**2)/2
        self.weighted_entropy = tf.reduce_mean(tf.losses.softmax_cross_entropy(y_,
            self.pruned_logits, self.sample_weights, reduction=tf.losses.Reduction.NONE))
        self.unweighted_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_,
            logits=self.pruned_logits))
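    # Here self.mse is a squared-logit penalty, self.weighted_entropy scales
    # the per-example cross-entropy by self.sample_weights, and
    # self.unweighted_entropy is the plain mean cross-entropy.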
    # Create operations for loss and gradient calculation
    self.loss_and_gradients(self.imp_method)
    if self.imp_method != 'PNN':
        # Store the current weights before doing a train step
        self.get_current_weights()
    # For the GEM variants, the train ops are defined later
    if 'GEM' not in self.imp_method:
        # Define the training operation here, as the path-integral ops depend on the train ops
        self.train_op()
    # Create operations to compute importance, depending on the importance method
    if self.imp_method == 'EWC':
        self.create_fisher_ops()
    elif self.imp_method == 'M-EWC':
        self.create_fisher_ops()
        self.create_pathint_ops()
        self.combined_fisher_pathint_ops()
    elif self.imp_method == 'PI':
        self.create_pathint_ops()
    elif self.imp_method == 'RWALK':
        self.create_fisher_ops()
        self.create_pathint_ops()
    elif self.imp_method == 'MAS':
        self.create_hebbian_ops()
    elif self.imp_method == 'A-GEM' or self.imp_method == 'S-GEM':
        self.create_stochastic_gem_ops()
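    # Method glossary: EWC = Elastic Weight Consolidation, M-EWC = EWC combined
    # with the path-integral measure, PI = Path Integral (Synaptic Intelligence),
    # RWALK = Riemannian Walk, MAS = Memory Aware Synapses, A-GEM/S-GEM =
    # averaged/stochastic Gradient Episodic Memory, ER = Experience Replay,
    # PNN = Progressive Neural Networks.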
    if self.imp_method != 'PNN':
        # Create the weight save and restore ops
        self.weights_store_ops()
    # Summary operations for visualization
    if isinstance(self.unweighted_entropy, list):
        # Per-task losses (PNN, and A-GEM/ER on non-FC architectures) get one scalar each
        for i, task_entropy in enumerate(self.unweighted_entropy):
            tf.summary.scalar("unweighted_entropy_task_%d" % i, task_entropy)
    else:
        tf.summary.scalar("unweighted_entropy", self.unweighted_entropy)
    for v in self.trainable_vars:
        tf.summary.histogram(v.name.replace(":", "_"), v)
    self.merged_summary = tf.summary.merge_all()
    # Accuracy measure
    if (self.imp_method == 'PNN') or ((self.imp_method == 'A-GEM' or self.imp_method == 'ER') and 'FC-' not in self.network_arch):
        self.correct_predictions = []
        self.accuracy = []
        for i in range(self.num_tasks):
            if self.imp_method == 'PNN':
                self.correct_predictions.append(tf.equal(tf.argmax(self.task_pruned_logits[i], 1), tf.argmax(y_[i], 1)))
            else:
                self.correct_predictions.append(tf.equal(tf.argmax(self.task_pruned_logits[i], 1), tf.argmax(y_, 1)))
            self.accuracy.append(tf.reduce_mean(tf.cast(self.correct_predictions[i], tf.float32)))
    else:
        self.correct_predictions = tf.equal(tf.argmax(self.pruned_logits, 1), tf.argmax(y_, 1))
        self.accuracy = tf.reduce_mean(tf.cast(self.correct_predictions, tf.float32))
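# Usage sketch (an assumption for illustration; the surrounding class and its
# constructor are not shown in this snippet): this method is presumably called
# once while building the graph, with x the input placeholder and y_ the
# one-hot label placeholder. Note that y_ must be a list with one placeholder
# per task when imp_method == 'PNN', since the PNN branches index y_[i].
#   x  = tf.placeholder(tf.float32, [None, input_dim])
#   y_ = tf.placeholder(tf.float32, [None, total_classes])
#   self.loss_and_train_ops_for_one_hot_vector(x, y_)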