def loss_and_train_ops_for_attr_vector(self, x, y_)

in model/model.py


    def loss_and_train_ops_for_attr_vector(self, x, y_):
        """
        Loss and training operations for training the joint embedding model
        """
        # Define the appropriate network
        if self.network_arch == 'FC-S':
            input_dim = int(x.get_shape()[1])
            layer_dims = [input_dim, 256, 256, self.total_classes]
            self.fc_variables(layer_dims)
            logits = self.fc_feedforward(x, self.weights, self.biases)

        elif self.network_arch == 'FC-B':
            input_dim = int(x.get_shape()[1])
            layer_dims = [input_dim, 2000, 2000, self.total_classes]
            self.fc_variables(layer_dims)
            logits = self.fc_feedforward(x, self.weights, self.biases)

        elif self.network_arch == 'CNN':
            num_channels = int(x.get_shape()[-1])
            self.image_size = int(x.get_shape()[1])
            kernels = [3, 3, 3, 3, 3]
            depth = [num_channels, 32, 32, 64, 64, 512]
            self.conv_variables(kernels, depth)
            logits = self.conv_feedforward(x, self.weights, self.biases, apply_dropout=True)

        elif self.network_arch == 'VGG':
            # VGG-16
            phi_x = self.vgg_16_conv_feedforward(x)
            
        elif self.network_arch == 'RESNET-S':
            # Smaller ResNet-18 (reduced filter widths)
            kernels = [3, 3, 3, 3, 3]
            filters = [20, 20, 40, 80, 160]
            strides = [1, 0, 2, 2, 2]
            # Get the image features
            phi_x = self.resnet18_conv_feedforward(x, kernels, filters, strides)

        elif self.network_arch == 'RESNET-B':
            # Standard ResNet-18
            kernels = [7, 3, 3, 3, 3]
            filters = [64, 64, 128, 256, 512]
            strides = [2, 0, 2, 2, 2]
            # Get the image features
            phi_x = self.resnet18_conv_feedforward(x, kernels, filters, strides)
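
        # NOTE: only the VGG and RESNET branches produce the image features
        # `phi_x` that the joint-embedding loss below consumes; the FC and CNN
        # branches compute class logits directly and leave `phi_x` undefined.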

        # Get the attribute embeddings (no biases yet); shape: TOTAL_CLASSES x image_feature_dim
        attr_embed = self.get_attribute_embedding(self.class_attr)
        # Add the biases now
        last_layer_biases = bias_variable([self.total_classes], name='attr_embed_b')
        self.trainable_vars.append(last_layer_biases)

        # Now that we have all the trainable variables, initialize the different bookkeeping variables
        # Note: This method has to be called before calculating Fisher
        # or any other importance measure
        self.init_vars()

        # Compute the logits for the zero-shot (ZST) case
        zst_logits = tf.matmul(phi_x, tf.transpose(attr_embed)) + last_layer_biases
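        # phi_x is [batch, image_feature_dim] and attr_embed is
        # [TOTAL_CLASSES, image_feature_dim], so zst_logits is [batch, TOTAL_CLASSES]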
        # Prune the predictions to only include the classes for which
        # the training data is present
        if self.imp_method == 'A-GEM':
            pruned_zst_logits = []
            self.unweighted_entropy = []
            for i in range(self.num_tasks):
                # Mask classes outside task i to NEG_INF so softmax assigns them
                # (effectively) zero probability
                task_mask = tf.tile(tf.equal(self.output_mask[i][None, :], 1.0),
                                    [tf.shape(zst_logits)[0], 1])
                pruned_zst_logits.append(tf.where(
                    task_mask, zst_logits, NEG_INF * tf.ones_like(zst_logits)))
                cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
                    labels=y_, logits=pruned_zst_logits[i])
                # Zero out the loss for examples whose label lies outside task i
                adjusted_entropy = tf.reduce_sum(
                    tf.cast(task_mask, dtype=tf.float32) * y_, axis=1) * cross_entropy
                self.unweighted_entropy.append(tf.reduce_sum(adjusted_entropy))
        else:
            pruned_zst_logits = tf.where(
                tf.tile(tf.equal(self.output_mask[None, :], 1.0),
                        [tf.shape(zst_logits)[0], 1]),
                zst_logits, NEG_INF * tf.ones_like(zst_logits))
            self.unweighted_entropy = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=pruned_zst_logits))
            self.mse = 2.0 * tf.nn.l2_loss(pruned_zst_logits)  # tf.nn.l2_loss computes sum(t ** 2) / 2

        # Create operations for loss and gradient calculation
        self.loss_and_gradients(self.imp_method)

        # Store the current weights before doing a train step
        self.get_current_weights()

        if 'GEM' not in self.imp_method:
            self.train_op()

        # Create operations to compute importance depending on the importance method
        if self.imp_method == 'EWC':
            self.create_fisher_ops()
        elif self.imp_method == 'M-EWC':
            self.create_fisher_ops()
            self.create_pathint_ops()
            self.combined_fisher_pathint_ops()
        elif self.imp_method == 'PI':
            self.create_pathint_ops()
        elif self.imp_method == 'RWALK':
            self.create_fisher_ops()
            self.create_pathint_ops()
        elif self.imp_method == 'MAS':
            self.create_hebbian_ops()
        elif (self.imp_method == 'A-GEM') or (self.imp_method == 'S-GEM'):
            self.create_stochastic_gem_ops()

        # Create weight save and store ops
        self.weights_store_ops()

        # Summary operations for visualization
        # Note: the "triplet_loss" tag actually logs the softmax cross-entropy;
        # for A-GEM, self.unweighted_entropy is a per-task list, which
        # tf.summary.scalar cannot log as a single scalar
        tf.summary.scalar("triplet_loss", self.unweighted_entropy)
        for v in self.trainable_vars:
            tf.summary.histogram(v.name.replace(":", "_"), v)
        self.merged_summary = tf.summary.merge_all()

        # Accuracy measure
        if self.imp_method == 'A-GEM' and 'FC-' not in self.network_arch:
            self.correct_predictions = []
            self.accuracy = []
            for i in range(self.num_tasks):
                self.correct_predictions.append(tf.equal(tf.argmax(pruned_zst_logits[i], 1), tf.argmax(y_, 1)))
                self.accuracy.append(tf.reduce_mean(tf.cast(self.correct_predictions[i], tf.float32)))
        else:
            self.correct_predictions = tf.equal(tf.argmax(pruned_zst_logits, 1), tf.argmax(y_, 1))
            self.accuracy = tf.reduce_mean(tf.cast(self.correct_predictions, tf.float32))
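
How the zero-shot logits are assembled: the class-attribute matrix is projected into
image-feature space, and each class score is the dot product between the image
features phi_x and that class's embedding, plus a bias. Below is a minimal NumPy
sketch of just this shape algebra; the dimensions and the linear projection W are
illustrative assumptions, not values from the repository (get_attribute_embedding
may be any learned map from attribute space to feature space).

    import numpy as np

    rng = np.random.default_rng(0)
    batch, feature_dim, total_classes, attr_dim = 4, 512, 10, 85  # assumed sizes

    phi_x = rng.normal(size=(batch, feature_dim))             # image features
    class_attr = rng.normal(size=(total_classes, attr_dim))   # per-class attribute vectors
    W = 0.01 * rng.normal(size=(attr_dim, feature_dim))       # stand-in for the learned embedding

    attr_embed = class_attr @ W                                # TOTAL_CLASSES x image_feature_dim
    last_layer_biases = np.zeros(total_classes)
    zst_logits = phi_x @ attr_embed.T + last_layer_biases     # batch x TOTAL_CLASSES
    print(zst_logits.shape)                                   # (4, 10)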
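
Why masking with NEG_INF prunes the predictions: pushing out-of-task logits to a
very large negative value makes softmax assign those classes effectively zero
probability, so both the cross-entropy and the argmax only ever see the classes for
which training data is present. A small self-contained sketch (plain NumPy; the
NEG_INF value is an assumption standing in for the module-level constant):

    import numpy as np

    NEG_INF = -1e32  # stand-in for the constant used in the model code

    def softmax(z):
        z = z - z.max(axis=1, keepdims=True)  # stabilize before exponentiating
        e = np.exp(z)
        return e / e.sum(axis=1, keepdims=True)

    logits = np.array([[2.0, 1.0, 0.5, 3.0]])     # scores over 4 classes
    output_mask = np.array([1.0, 1.0, 0.0, 0.0])  # current task owns classes 0 and 1
    pruned = np.where(output_mask == 1.0, logits, NEG_INF)

    print(softmax(logits))  # mass spread over all 4 classes
    print(softmax(pruned))  # mass only on classes 0 and 1; classes 2 and 3 get ~0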