def loss_and_gradients()

in model/model.py


    def loss_and_gradients(self, imp_method):
        """
        Defines task based and surrogate losses and their
        gradients
        Args:

        Returns:
        """
        reg = 0.0
        if imp_method in ('VAN', 'PNN', 'ER') or 'GEM' in imp_method:
            # These methods use no parameter-importance surrogate
            pass
        elif imp_method == 'EWC' or imp_method == 'M-EWC':
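            # EWC / M-EWC: quadratic penalty sum_i F_i * (w_i - w_i*)^2, with F
            # the normalized Fisher information at the previous task's minimum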
            reg = tf.add_n([tf.reduce_sum(tf.square(w - w_star) * f) for w, w_star, 
                f in zip(self.trainable_vars, self.star_vars, self.normalized_fisher_at_minima_vars)])
        elif imp_method == 'PI':
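            # PI: same quadratic form, weighted by the accumulated
            # path-integral importance (big omega)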
            reg = tf.add_n([tf.reduce_sum(tf.square(w - w_star) * f) for w, w_star, 
                f in zip(self.trainable_vars, self.star_vars, self.big_omega_vars)])
        elif imp_method == 'MAS':
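            # MAS: quadratic form weighted by the accumulated Hebbian
            # importance scores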
            reg = tf.add_n([tf.reduce_sum(tf.square(w - w_star) * f) for w, w_star, 
                f in zip(self.trainable_vars, self.star_vars, self.hebbian_score_vars)])
        elif imp_method == 'RWALK':
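            # RWALK: importance is the sum of the normalized Fisher and the
            # normalized path-based parameter scores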
            reg = tf.add_n([tf.reduce_sum(tf.square(w - w_star) * (f + scr)) for w, w_star, 
                f, scr in zip(self.trainable_vars, self.star_vars, self.normalized_fisher_at_minima_vars, 
                    self.normalized_score_vars)])
     
        """
        # ***** DON't USE THIS WITH MULTI-HEAD SETTING SINCE THIS WILL UPDATE ALL THE WEIGHTS *****
        # If CNN arch, then use the weight decay
        if self.is_ATT_DATASET:
            self.unweighted_entropy += tf.add_n([0.0005 * tf.nn.l2_loss(v) for v in self.trainable_vars if 'weights' in v.name or 'kernel' in v.name])
        """
        
        if imp_method == 'PNN':
            # Compute the gradients of regularized loss
            self.reg_gradients_vars = []
            for i in range(self.num_tasks):
                # One set of gradients per task column
                self.reg_gradients_vars.append(
                        self.opt.compute_gradients(self.unweighted_entropy[i], var_list=self.trainable_vars[i]))
        elif imp_method != 'A-GEM': # For A-GEM we will define the losses and gradients later on
            if imp_method == 'ER' and 'FC-' not in self.network_arch:
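                # ER with a non-FC (conv) arch: sum the per-task replay losses
                # and normalize by the episodic-memory batch size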
                self.reg_loss = tf.add_n([self.unweighted_entropy[i] for i in range(self.num_tasks)]) / self.mem_batch_size
            else:
                # Regularized training loss
                self.reg_loss = tf.squeeze(self.unweighted_entropy + self.synap_stgth * reg)
                # Compute the gradients of the vanilla loss
                self.vanilla_gradients_vars = self.opt.compute_gradients(self.unweighted_entropy, 
                        var_list=self.trainable_vars)
            # Compute the gradients of regularized loss
            self.reg_gradients_vars = self.opt.compute_gradients(self.reg_loss, 
                    var_list=self.trainable_vars)
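
For reference, every surrogate branch above reduces to the same quadratic form with a different per-parameter importance weight (Fisher for EWC, big omega for PI, Hebbian scores for MAS, Fisher plus scores for RWALK). A minimal NumPy sketch of that form; the quadratic_surrogate name and the plain-array inputs are illustrative, not part of the model class:

    import numpy as np

    def quadratic_surrogate(trainable_vars, star_vars, importance):
        # Mirrors the tf.add_n([tf.reduce_sum(tf.square(w - w_star) * f) ...])
        # pattern used by the EWC/PI/MAS/RWALK branches above
        return sum(np.sum(f * (w - w_star) ** 2)
                   for w, w_star, f in zip(trainable_vars, star_vars, importance))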