in archs/models.py [0:0]
def generate_predictions(self, scores, obj_truth, topk=10):
    """Turn pair scores into top-k (attribute, object) predictions.

    Args:
        scores: 2-D tensor of compatibility scores, one row per sample and
            one column per entry of ``self.pairs`` (the original comment
            gives the shape as ``(B, #pairs)``).
        obj_truth: integer index tensor used to select one row of
            ``self.oracle_obj_mask`` per sample (presumably the ground-truth
            object ids -- confirm against the caller).
        topk: number of predictions kept per sample. Defaults to 10, the
            value that was previously hard-coded. It is clamped to the
            number of pair columns so small pair sets do not raise.

    Returns:
        dict with the keys ``'open'``, ``'closed'`` and ``'object_oracle'``.
        Each value is an ``(attr_pred, obj_pred)`` tuple of tensors shaped
        ``(B, k)``, ordered by decreasing score. ``'closed'`` is the very
        same tuple object as ``'open'``.

    ``scores`` is not modified.
    """
    masked_out = -1e10  # score assigned to pairs excluded by a mask
    k = min(topk, scores.shape[1])

    def get_pred_from_scores(_scores):
        # Column indices of the k best-scoring pairs per row: (B, k).
        _, pair_idx = _scores.topk(k, dim=1)
        # Gather each selected pair once -> (B, k, 2), then split the
        # trailing (attr, obj) axis.
        attr_pred, obj_pred = self.pairs[pair_idx].unbind(dim=-1)
        return (attr_pred, obj_pred)

    results = {}

    # `~mask.bool()` replaces the former `1 - mask`: subtraction is not
    # defined for bool tensors in current PyTorch, while `.bool()` keeps
    # legacy uint8 masks working. `masked_fill` returns a new tensor and
    # broadcasts the mask over the batch, so no clone/repeat is needed.
    closed_scores = scores.masked_fill(~self.closed_mask.bool(), masked_out)

    # NOTE(review): the 'open' entry is computed from the closed-world
    # masked scores (pairs outside `self.closed_mask` are pushed to -1e10),
    # so 'open' and 'closed' are identical. Kept as in the original code;
    # confirm whether an unmasked open-world result was intended.
    results['open'] = get_pred_from_scores(closed_scores)
    results['closed'] = results['open']

    # Object-oracle setting: push to -1e10 every pair that is not enabled in
    # the oracle mask row selected by `obj_truth`.
    oracle_mask = self.oracle_obj_mask[obj_truth]
    oracle_obj_scores = scores.masked_fill(~oracle_mask.bool(), masked_out)
    results['object_oracle'] = get_pred_from_scores(oracle_obj_scores)

    return results