in archs/models.py [0:0]
def evaluate_predictions(self, predictions, attr_truth, obj_truth, topk=1):
    """Score attribute/object predictions against ground-truth labels.

    Args:
        predictions: mapping with keys ``'open'``, ``'closed'`` and
            ``'object_oracle'``. Each value is indexed as ``[0]`` for the
            attribute predictions and ``[1]`` for the object predictions,
            both 2-D tensors with one row per sample. Column ``j`` presumably
            holds the ``j``-th ranked guess (best first) -- confirm against
            the caller.
        attr_truth: 1-D tensor of ground-truth attribute labels.
        obj_truth: 1-D tensor of ground-truth object labels.
        topk: number of leading open-world prediction columns that count as
            a hit. If fewer columns exist, all of them are used.

    Returns:
        Tuple ``(attr_match, obj_match, closed_match, open_match,
        obj_oracle_match, open_seen_match, open_unseen_match)`` of float
        tensors holding 1.0 for a hit and 0.0 for a miss.

        - ``attr_match`` / ``obj_match``: open-world top-k hits per sample,
          scored separately for attribute and object.
        - ``open_match``: a single top-k column gets both right.
        - ``closed_match`` / ``obj_oracle_match``: the first column of the
          ``'closed'`` / ``'object_oracle'`` predictions gets both right.
        - ``open_seen_match`` / ``open_unseen_match``: ``open_match``
          restricted to samples whose ``(attr, obj)`` pair is / is not in
          ``self.train_pairs``, in their original order.
    """
    # Put everything on CPU so no comparison below mixes devices.
    # .cpu() returns the tensor itself when it is already on the CPU.
    attr_truth, obj_truth = attr_truth.cpu(), obj_truth.cpu()

    # Seen/unseen split in one pass. Labels become plain Python ints, so
    # tuples hash and compare like the training pairs do.
    try:
        train_pairs = set(self.train_pairs)
    except TypeError:
        # Unhashable entries: fall back to the container's own membership.
        train_pairs = self.train_pairs
    seen_mask = torch.tensor(
        [pair in train_pairs
         for pair in zip(attr_truth.tolist(), obj_truth.tolist())],
        dtype=torch.bool)

    # Open world: a hit if any of the first `topk` columns matches.
    # Broadcasting (N, 1) against (N, <=topk) replaces an explicit repeat.
    # It also still works when fewer than `topk` columns are available.
    open_attr = predictions['open'][0].cpu()[:, :topk]
    open_obj = predictions['open'][1].cpu()[:, :topk]
    attr_hits = attr_truth.unsqueeze(1) == open_attr
    obj_hits = obj_truth.unsqueeze(1) == open_obj
    # The pair must be right within the *same* column, not just somewhere.
    open_match = (attr_hits & obj_hits).any(1).float()
    attr_match = attr_hits.any(1).float()
    obj_match = obj_hits.any(1).float()
    open_seen_match = open_match[seen_mask]
    open_unseen_match = open_match[~seen_mask]

    def _top1_pair_match(key):
        # Top-1 pair hit for predictions[key]: attribute and object both
        # match in column 0.
        attr_pred = predictions[key][0].cpu()[:, 0]
        obj_pred = predictions[key][1].cpu()[:, 0]
        return ((attr_truth == attr_pred) & (obj_truth == obj_pred)).float()

    # Closed world and object oracle: top-1 pair accuracy only.
    closed_match = _top1_pair_match('closed')
    obj_oracle_match = _top1_pair_match('object_oracle')

    return (attr_match, obj_match, closed_match, open_match, obj_oracle_match,
            open_seen_match, open_unseen_match)