in ttw/models/landmark_classification.py [0:0]
def forward(self, batch):
    """Score one batch and return its loss and weighted metrics.

    Every enabled feature source (``textrecog``, ``fasttext``, ``resnet``)
    is passed through its own projection layer, pooled over the second
    axis with either a sum or a max (selected by ``self.pool``), and the
    pooled scores of all enabled sources are added into one logit matrix
    of shape ``(batchsize, self.num_classes)``.

    Args:
        batch: mapping holding ``'target'`` and ``'weight'`` tensors plus
            one entry per enabled feature source. The embedded
            ``'textrecog'`` tensor, ``'fasttext'`` and ``'resnet'`` are
            treated as 3-D with the example index first -- presumably
            (batch, items, feature_dim); confirm against the data loader.

    Returns:
        dict with the keys ``'loss'`` (``self.loss`` applied to the
        flattened logits and targets) and ``'f1'``, ``'precision'``,
        ``'recall'`` (each computed with ``average='weighted'`` on
        predictions obtained by thresholding ``sigmoid(logits)`` at 0.5).

    Side effects:
        * ``self.loss.weight`` is overwritten with the flattened
          ``batch['weight']`` on every call.
        * ``batch['target']`` is replaced by its float version; this
          mutation of the caller's dict is kept for backward
          compatibility.
    """
    batchsize = batch['target'].size(0)
    use_sum = self.pool == 'sum'

    def pooled_scores(layer, features):
        # One batched call replaces the former per-example Python loop.
        # NOTE(review): this assumes `layer` acts on the last axis and
        # broadcasts over the leading ones, as torch.nn.Linear does --
        # confirm against the layer definitions in __init__.
        # Only the first `batchsize` rows are used, matching the old loop
        # bound taken from batch['target'].
        scores = layer(features[:batchsize])
        if use_sum:
            return scores.sum(dim=1)
        # max(dim=...) returns (values, indices); only the values are kept.
        return scores.max(dim=1)[0]

    # Start from zeros so the result is well defined (all-zero logits)
    # even when no feature source is enabled.
    logits = Variable(torch.FloatTensor(batchsize, self.num_classes)).zero_()

    if self.textrecog_features:
        embeddings = self.embed(batch['textrecog'])
        logits = logits + pooled_scores(self.textrecog_linear, embeddings)
    if self.fasttext_features:
        logits = logits + pooled_scores(self.fasttext_linear, batch['fasttext'])
    if self.resnet_features:
        # The ResNet features are divided by 10 before projection; the
        # constant is carried over unchanged from the original code.
        logits = logits + pooled_scores(self.resnet_linear, batch['resnet'] / 10)

    # The loss weight is swapped in per batch, flattened to line up with
    # the flattened logits/targets passed to the loss below.
    self.loss.weight = batch['weight'].view(-1).data

    out = dict()
    batch['target'] = batch['target'].float()
    target = batch['target'].view(-1)
    out['loss'] = self.loss(logits.view(-1), target)

    # Metrics are computed on NumPy arrays of hard 0/1 predictions.
    y_pred = torch.ge(self.sigmoid(logits), 0.5).float().data.numpy()
    y_true = batch['target'].data.numpy()
    out['f1'] = f1_score(y_true, y_pred, average='weighted')
    out['precision'] = precision_score(y_true, y_pred, average='weighted')
    out['recall'] = recall_score(y_true, y_pred, average='weighted')
    return out