in archs/models.py [0:0]
def train_forward_softmax(self, x):
    """Training forward pass scored with softmax cross-entropy.

    For every image, builds a candidate set whose column 0 is the
    ground-truth (attr, obj) pair and whose remaining columns are the
    sampled negative pairs, composes each pair into an embedding, scores
    every (image, candidate) combination with the composition network, and
    optimizes cross-entropy with the ground-truth pair (index 0) as the
    target class.

    Args:
        x: batch sequence; the indices read here are
            x[0] image features, x[1] gt attribute ids, x[2] gt object ids,
            x[4] negative attribute ids (B, K), x[5] negative object ids (B, K).
            (x[3], x[6], x[7] belong to the batch layout but are unused by
            this objective.)

    Returns:
        Tuple of (total loss, dict of named losses, accuracy,
        (pair_scores, feat)). All auxiliary losses are zero placeholders.
    """
    img, attrs, objs = x[0], x[1], x[2]
    neg_attrs, neg_objs = x[4], x[5]

    # Column 0 holds the positive pair, columns 1..K the negatives: (B, K+1).
    sampled_attrs = torch.cat((attrs.unsqueeze(1), neg_attrs), 1)
    sampled_objs = torch.cat((objs.unsqueeze(1), neg_objs), 1)

    # Pair each image index with all of its K+1 candidates. expand (a view)
    # + reshape avoids materializing the repeated index tensor.
    num_candidates = sampled_attrs.shape[1]
    img_ind = torch.arange(sampled_objs.shape[0]).unsqueeze(1).expand(
        -1, num_candidates)
    flat_sampled_attrs = sampled_attrs.reshape(-1)
    flat_sampled_objs = sampled_objs.reshape(-1)
    flat_img_ind = img_ind.reshape(-1)

    # Ground truth is always column 0 of each row's candidate set.
    labels = torch.zeros_like(sampled_attrs[:, 0]).long()

    self.composed_g = self.compose(flat_sampled_attrs, flat_sampled_objs)
    cls_scores, feat = self.comp_network(
        img[flat_img_ind], self.composed_g, return_feat=True)

    # Keep only the pair-compatibility score (first column) and fold back
    # to (B, K+1) so each row is a softmax over one image's candidates.
    pair_scores = cls_scores[:, :1].view(*sampled_attrs.shape)

    loss_cls = F.cross_entropy(pair_scores, labels)
    loss = loss_cls

    # Placeholder auxiliary losses (unused by this objective). Created on
    # the same device as the scores so downstream aggregation/logging never
    # mixes CPU and GPU tensors (legacy torch.FloatTensor always built
    # these on CPU).
    device = pair_scores.device
    loss_obj = torch.zeros(1, device=device)
    loss_attr = torch.zeros(1, device=device)
    loss_sparse = torch.zeros(1, device=device)
    loss_unif = torch.zeros(1, device=device)
    loss_aux = torch.zeros(1, device=device)

    # Fraction of rows whose best-scoring candidate is the ground truth.
    acc = (pair_scores.argmax(1) == labels).float().mean()

    all_losses = {
        'total_loss': loss,
        'main_loss': loss_cls,
        'aux_loss': loss_aux,
        'obj_loss': loss_obj,
        'attr_loss': loss_attr,
        'sparse_loss': loss_sparse,
        'unif_loss': loss_unif,
    }
    return loss, all_losses, acc, (pair_scores, feat)