def train_forward_softmax()

in archs/models.py [0:0]


    def train_forward_softmax(self, x):
        """Training forward pass with a softmax loss over sampled pairs.

        Every image is scored against its ground-truth (attribute, object)
        pair and against its sampled negative pairs. The ground-truth pair
        is always placed in column 0, so the cross-entropy target is 0 for
        every image.

        Side effect: stores the composed pair embeddings returned by
        ``self.compose`` on ``self.composed_g``.

        Args:
            x: Indexable batch. Entries read here are ``x[0]`` (images),
                ``x[1]`` / ``x[2]`` (attribute / object labels) and
                ``x[4]`` / ``x[5]`` (negative attributes / objects).
                Presumably the labels are ``(B,)`` and the negatives
                ``(B, K)`` index tensors -- TODO confirm against the
                data loader.

        Returns:
            A tuple ``(loss, all_losses, acc, (pair_scores, feat))``:
            ``loss`` is the total loss, ``all_losses`` maps loss names to
            tensors (the auxiliary entries are zero placeholders), ``acc``
            is the fraction of images whose best-scoring pair is the
            ground-truth one, ``pair_scores`` has the shape of the sampled
            pair grid and ``feat`` is the second output of
            ``self.comp_network``.
        """
        img, attrs, objs = x[0], x[1], x[2]
        neg_attrs, neg_objs = x[4], x[5]

        # Column 0 is the positive pair; the other columns are negatives.
        sampled_attrs = torch.cat((attrs.unsqueeze(1), neg_attrs), 1)
        sampled_objs = torch.cat((objs.unsqueeze(1), neg_objs), 1)
        num_imgs, num_pairs = sampled_objs.shape[0], sampled_attrs.shape[1]

        # Row i of the flattened pair grid belongs to image i // num_pairs.
        # Built on the image's device so the gather below needs no
        # cross-device index transfer.
        flat_img_ind = torch.arange(
            num_imgs, device=img.device).repeat_interleave(num_pairs)
        flat_sampled_attrs = sampled_attrs.view(-1)
        flat_sampled_objs = sampled_objs.view(-1)

        # The positive pair sits at index 0 for every image.
        labels = torch.zeros_like(sampled_attrs[:, 0]).long()

        self.composed_g = self.compose(flat_sampled_attrs, flat_sampled_objs)

        cls_scores, feat = self.comp_network(
            img[flat_img_ind], self.composed_g, return_feat=True)
        # Only the first score column is used; fold it back into one row of
        # pair scores per image.
        pair_scores = cls_scores[:, :1].view(*sampled_attrs.shape)

        loss_cls = F.cross_entropy(pair_scores, labels)
        # Accumulate into a fresh value so 'total_loss' stays a separate
        # tensor from 'main_loss'.
        loss = 0
        loss += loss_cls

        acc = (pair_scores.argmax(1) == labels).sum().float() / float(
            len(labels))

        all_losses = {'total_loss': loss, 'main_loss': loss_cls}
        # The remaining terms are not computed in this mode; each key gets
        # its own zero float32 placeholder so the dict layout is complete.
        for name in ('aux_loss', 'obj_loss', 'attr_loss', 'sparse_loss',
                     'unif_loss'):
            all_losses[name] = torch.zeros(1, dtype=torch.float32)

        return loss, all_losses, acc, (pair_scores, feat)