in model/model.py [0:0]
def loss_and_gradients(self, imp_method):
    """
    Defines the task-based and surrogate (importance-weighted) losses and
    their gradients.
    Args:
        imp_method: Continual-learning method, e.g. 'VAN', 'PNN', 'ER',
            'EWC', 'M-EWC', 'PI', 'MAS', 'RWALK' or a 'GEM' variant.
    Returns:
        None. Stores the losses and the (gradient, variable) pairs as
        attributes on the model.
    """
    reg = 0.0
    if imp_method == 'VAN' or imp_method == 'PNN' or imp_method == 'ER' or 'GEM' in imp_method:
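        # These methods use no parameter-importance penalty, so the
        # surrogate regularizer stays at zero.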
        pass
    elif imp_method == 'EWC' or imp_method == 'M-EWC':
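        # EWC-style quadratic penalty: each parameter's squared drift from
        # the previous optimum (star_vars) is weighted by its normalized
        # Fisher information, i.e. reg = sum_i F_i * (w_i - w_i*)^2.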
        reg = tf.add_n([tf.reduce_sum(tf.square(w - w_star) * f)
                        for w, w_star, f in zip(self.trainable_vars, self.star_vars,
                                                self.normalized_fisher_at_minima_vars)])
    elif imp_method == 'PI':
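        # Path Integral (Synaptic Intelligence): same quadratic form, with
        # importance given by the accumulated per-parameter path integral
        # (big_omega_vars).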
        reg = tf.add_n([tf.reduce_sum(tf.square(w - w_star) * f)
                        for w, w_star, f in zip(self.trainable_vars, self.star_vars,
                                                self.big_omega_vars)])
    elif imp_method == 'MAS':
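        # Memory Aware Synapses: importance given by the Hebbian-style
        # sensitivity scores (hebbian_score_vars).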
        reg = tf.add_n([tf.reduce_sum(tf.square(w - w_star) * f)
                        for w, w_star, f in zip(self.trainable_vars, self.star_vars,
                                                self.hebbian_score_vars)])
    elif imp_method == 'RWALK':
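        # Riemannian Walk: combines the normalized Fisher information with
        # the normalized path-based scores, i.e.
        # reg = sum_i (F_i + s_i) * (w_i - w_i*)^2.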
        reg = tf.add_n([tf.reduce_sum(tf.square(w - w_star) * (f + scr))
                        for w, w_star, f, scr in zip(self.trainable_vars, self.star_vars,
                                                     self.normalized_fisher_at_minima_vars,
                                                     self.normalized_score_vars)])
"""
# ***** DON't USE THIS WITH MULTI-HEAD SETTING SINCE THIS WILL UPDATE ALL THE WEIGHTS *****
# If CNN arch, then use the weight decay
if self.is_ATT_DATASET:
self.unweighted_entropy += tf.add_n([0.0005 * tf.nn.l2_loss(v) for v in self.trainable_vars if 'weights' in v.name or 'kernel' in v.name])
"""
    if imp_method == 'PNN':
        # Compute the gradients of the regularized loss, one set of
        # (per-column) variables per task
        self.reg_gradients_vars = []
        for i in range(self.num_tasks):
            self.reg_gradients_vars.append(
                self.opt.compute_gradients(self.unweighted_entropy[i],
                                           var_list=self.trainable_vars[i]))
    elif imp_method != 'A-GEM':  # For A-GEM the losses and gradients are defined later on
        if imp_method == 'ER' and 'FC-' not in self.network_arch:
            # ER with a CNN arch: sum the per-task losses and normalize by
            # the episodic memory batch size
            self.reg_loss = tf.add_n([self.unweighted_entropy[i]
                                      for i in range(self.num_tasks)]) / self.mem_batch_size
        else:
            # Regularized training loss: task loss plus the weighted surrogate penalty
            self.reg_loss = tf.squeeze(self.unweighted_entropy + self.synap_stgth * reg)
            # Compute the gradients of the vanilla (unregularized) loss
            self.vanilla_gradients_vars = self.opt.compute_gradients(self.unweighted_entropy,
                                                                     var_list=self.trainable_vars)
        # Compute the gradients of the regularized loss
        self.reg_gradients_vars = self.opt.compute_gradients(self.reg_loss,
                                                             var_list=self.trainable_vars)
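A minimal usage sketch, not part of this file: elsewhere in the class, the
(gradient, variable) pairs computed above would typically be fed to
`apply_gradients`, assuming `self.opt` is a TF1.x `tf.train.Optimizer`; the
name `train_op` is illustrative only.

    # Hypothetical wiring of the regularized gradients into a training op
    # (assumes TF1.x graph mode; `train_op` is an illustrative name).
    self.train_op = self.opt.apply_gradients(self.reg_gradients_vars)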