def loss_and_train_ops_for_one_hot_vector()

in model/model.py


    def loss_and_train_ops_for_one_hot_vector(self, x, y_):
        """
        Loss and training operations for training a one-hot-vector-based classification model
        """
        # Define the appropriate network
        if self.network_arch == 'FC-S':
            input_dim = int(x.get_shape()[1])
            layer_dims = [input_dim, 256, 256, self.total_classes]
            if self.imp_method == 'PNN':
                self.task_logits = []
                self.task_pruned_logits = []
                self.unweighted_entropy = []
                for i in range(self.num_tasks):
                    # Progressive nets: one network column per task; columns for
                    # i > 0 can also draw on the earlier columns
                    if i == 0:
                        self.task_logits.append(self.init_fc_column_progNN(layer_dims, x))
                    else:
                        self.task_logits.append(self.extensible_fc_column_progNN(layer_dims, x, i))
                    # Mask out the logits of classes that do not belong to task i
                    self.task_pruned_logits.append(tf.where(tf.tile(tf.equal(self.output_mask[i][None,:], 1.0), [tf.shape(self.task_logits[i])[0], 1]), self.task_logits[i], NEG_INF*tf.ones_like(self.task_logits[i])))
                    self.unweighted_entropy.append(tf.squeeze(tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_[i], logits=self.task_pruned_logits[i])))) # mult. by mean(y_[i]) would put unwarranted loss to 0
            else:
                self.fc_variables(layer_dims)
                logits = self.fc_feedforward(x, self.weights, self.biases)

        elif self.network_arch == 'FC-B':
            input_dim = int(x.get_shape()[1])
            layer_dims = [input_dim, 2000, 2000, self.total_classes]
            self.fc_variables(layer_dims)
            logits = self.fc_feedforward(x, self.weights, self.biases)

        elif self.network_arch == 'CNN':
            num_channels = int(x.get_shape()[-1])
            self.image_size = int(x.get_shape()[1])
            kernels = [3, 3, 3, 3, 3]
            depth = [num_channels, 32, 32, 64, 64, 512]
            self.conv_variables(kernels, depth)
            logits = self.conv_feedforward(x, self.weights, self.biases, apply_dropout=True)

        elif self.network_arch == 'VGG':
            # VGG-16
            logits = self.vgg_16_conv_feedforward(x)
            
        elif 'RESNET-' in self.network_arch:
            if self.network_arch == 'RESNET-S':
                # Same ResNet-18 as used in the GEM paper
                kernels = [3, 3, 3, 3, 3]
                filters = [20, 20, 40, 80, 160]
                strides = [1, 0, 2, 2, 2]
            elif self.network_arch == 'RESNET-B':
                # Standard ResNet-18
                kernels = [7, 3, 3, 3, 3]
                filters = [64, 64, 128, 256, 512]
                strides = [2, 0, 2, 2, 2]
            if self.imp_method == 'PNN':
                self.task_logits = []
                self.task_pruned_logits = []
                self.unweighted_entropy = []
                for i in range(self.num_tasks):
                    if i == 0:
                        self.task_logits.append(self.init_resent_column_progNN(x, kernels, filters, strides))
                    else:
                        self.task_logits.append(self.extensible_resnet_column_progNN(x, kernels, filters, strides, i))
                    self.task_pruned_logits.append(tf.where(tf.tile(tf.equal(self.output_mask[i][None,:], 1.0), [tf.shape(self.task_logits[i])[0], 1]), self.task_logits[i], NEG_INF*tf.ones_like(self.task_logits[i])))
                    self.unweighted_entropy.append(tf.squeeze(tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_[i], logits=self.task_pruned_logits[i]))))
            elif self.imp_method == 'A-GEM' or self.imp_method == 'ER':
                logits = self.resnet18_conv_feedforward(x, kernels, filters, strides)
                self.task_pruned_logits = []
                self.unweighted_entropy = []
                for i in range(self.num_tasks):
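                    # Per-task view of the shared logits: classes outside task i's
                    # output mask are pushed to NEG_INF before the softmax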
                    self.task_pruned_logits.append(tf.where(tf.tile(tf.equal(self.output_mask[i][None,:], 1.0), [tf.shape(logits)[0], 1]), logits, NEG_INF*tf.ones_like(logits)))
                    cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=self.task_pruned_logits[i])
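                    # Keep the loss only for examples whose true class falls inside
                    # task i's mask: reduce_sum(mask * y_, axis=1) is 1 there, 0 elsewhere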
                    adjusted_entropy = tf.reduce_sum(tf.cast(tf.tile(tf.equal(self.output_mask[i][None,:], 1.0), [tf.shape(y_)[0], 1]), dtype=tf.float32) * y_, axis=1) * cross_entropy
                    self.unweighted_entropy.append(tf.reduce_sum(adjusted_entropy)) # We will average it later on
            else:
                logits = self.resnet18_conv_feedforward(x, kernels, filters, strides)

        # Prune the predictions to only include the classes for which
        # the training data is present
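        # Logits pushed to NEG_INF receive (numerically) zero softmax probability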
        if (self.imp_method != 'PNN') and ((self.imp_method != 'A-GEM' and self.imp_method != 'ER') or 'FC-' in self.network_arch):
            self.pruned_logits = tf.where(tf.tile(tf.equal(self.output_mask[None,:], 1.0), [tf.shape(logits)[0], 1]), logits, NEG_INF*tf.ones_like(logits))

        # Create list of variables for storing different measures
        # Note: This method has to be called before calculating the Fisher
        # information or any other importance measure
        self.init_vars()

        # Different entropy measures / loss definitions
        if (self.imp_method != 'PNN') and ((self.imp_method != 'A-GEM' and self.imp_method != 'ER') or 'FC-' in self.network_arch):
            self.mse = 2.0*tf.nn.l2_loss(self.pruned_logits) # tf.nn.l2_loss computes sum(T**2) / 2
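            # Weighted loss: each example's cross-entropy is scaled by self.sample_weights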
            self.weighted_entropy = tf.reduce_mean(tf.losses.softmax_cross_entropy(y_, 
                self.pruned_logits, self.sample_weights, reduction=tf.losses.Reduction.NONE))
            self.unweighted_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, 
                logits=self.pruned_logits))

        # Create operations for loss and gradient calculation
        self.loss_and_gradients(self.imp_method)

        if self.imp_method != 'PNN':
            # Store the current weights before doing a train step
            self.get_current_weights()

        # For the GEM variants, the train ops will be defined later
        if 'GEM' not in self.imp_method:
            # Define the training operation here, as the Pathint ops depend on the train ops
            self.train_op()

        # Create operations to compute importance depending on the importance methods
        if self.imp_method == 'EWC':
            self.create_fisher_ops()
        elif self.imp_method == 'M-EWC':
            self.create_fisher_ops()
            self.create_pathint_ops()
            self.combined_fisher_pathint_ops()
        elif self.imp_method == 'PI':
            self.create_pathint_ops()
        elif self.imp_method == 'RWALK':
            self.create_fisher_ops()
            self.create_pathint_ops()
        elif self.imp_method == 'MAS':
            self.create_hebbian_ops()
        elif self.imp_method == 'A-GEM' or self.imp_method == 'S-GEM':
            self.create_stochastic_gem_ops()

        if self.imp_method != 'PNN':
            # Create weight save and store ops
            self.weights_store_ops()

            # Summary operations for visualization
            tf.summary.scalar("unweighted_entropy", self.unweighted_entropy)
            for v in self.trainable_vars:
                tf.summary.histogram(v.name.replace(":", "_"), v)
            self.merged_summary = tf.summary.merge_all()

        # Accuracy measure
        if (self.imp_method == 'PNN') or ((self.imp_method == 'A-GEM' or self.imp_method == 'ER') and 'FC-' not in self.network_arch):
            self.correct_predictions = []
            self.accuracy = []
            for i in range(self.num_tasks):
                if self.imp_method == 'PNN':
                    self.correct_predictions.append(tf.equal(tf.argmax(self.task_pruned_logits[i], 1), tf.argmax(y_[i], 1)))
                else:
                    self.correct_predictions.append(tf.equal(tf.argmax(self.task_pruned_logits[i], 1), tf.argmax(y_, 1)))
                self.accuracy.append(tf.reduce_mean(tf.cast(self.correct_predictions[i], tf.float32)))
        else:
            self.correct_predictions = tf.equal(tf.argmax(self.pruned_logits, 1), tf.argmax(y_, 1))
            self.accuracy = tf.reduce_mean(tf.cast(self.correct_predictions, tf.float32))
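
For reference, the class-pruning trick that recurs above can be reproduced in isolation. The snippet below is a minimal, self-contained sketch assuming TensorFlow 1.x; the NEG_INF value, logits, and mask are illustrative stand-ins, not taken from the repository:

    import tensorflow as tf  # TF 1.x, as in the listing above

    NEG_INF = -1e32  # illustrative sentinel; the repository defines its own constant

    logits = tf.constant([[2.0, 1.0, 0.5, 0.1]])     # one example, four classes
    output_mask = tf.constant([1.0, 1.0, 0.0, 0.0])  # hypothetical task covering classes 0-1

    # Broadcast the mask over the batch and push masked-out logits to NEG_INF,
    # mirroring the construction of self.pruned_logits above
    mask = tf.tile(tf.equal(output_mask[None, :], 1.0), [tf.shape(logits)[0], 1])
    pruned_logits = tf.where(mask, logits, NEG_INF * tf.ones_like(logits))

    with tf.Session() as sess:
        print(sess.run(tf.nn.softmax(pruned_logits)))
        # -> [[0.731, 0.269, 0.0, 0.0]]: all probability mass on the unmasked classes

Because softmax is shift-invariant, the surviving logits keep their relative scores while the masked classes get (numerically) zero probability, which keeps the per-task cross-entropy terms above well defined.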