in mmf/datasets/builders/okvqa/dataset.py [0:0]
def format_for_prediction(self, report):
    """Convert a model ``report`` into a list of per-question prediction dicts.

    Each returned dict has:
      - ``"question_id"``: int question id (from ``report.id``)
      - ``"topk"``: list of 5 ``(answer_string, score)`` pairs, best first
      - ``"answer"``: the single argmax answer string

    If the score matrix is wider than the true answer vocabulary, the extra
    columns are graph-predicted confidences over
    ``answer_processor.graph_vocab`` (assumed to be a subset of the regular
    vocabulary); they are collapsed into the regular columns by taking the
    per-answer max, and the graph columns are disabled.  Any surviving index
    at or above the true vocab size is treated as a pointer into
    ``report.context_tokens`` (copy mechanism); a pointer to the PAD token
    decodes to ``"unanswerable"``.
    """
    answer_space_size = self.answer_processor.get_true_vocab_size()

    # Check for the case of scores coming from a graph head.
    if report.scores.size(1) > answer_space_size:
        # Should actually have the graph_vqa_answer processor.
        assert type(self.answer_processor.processor) is GraphVQAAnswerProcessor
        # clone() keeps dtype/device (the old Tensor(...).copy_ forced a CPU
        # float32 copy) and leaves report.scores untouched.
        scores = report.scores.clone()
        for batch_ind in range(scores.size(0)):
            for graph_ind, graph_ans in enumerate(
                self.answer_processor.graph_vocab
            ):
                # Graph answers are assumed to be a subset of all answers.
                reg_idx = self.answer_processor.answer_vocab.word2idx(graph_ans)
                assert (
                    reg_idx != self.answer_processor.answer_vocab.UNK_INDEX
                    and reg_idx < answer_space_size
                )
                graph_conf = scores[batch_ind, answer_space_size + graph_ind].item()
                reg_conf = scores[batch_ind, reg_idx].item()
                # Keep the larger confidence and kill the graph column so it
                # can never be picked by topk/argmax below.
                scores[batch_ind, reg_idx] = max(graph_conf, reg_conf)
                scores[batch_ind, answer_space_size + graph_ind] = -float("Inf")
    else:
        scores = report.scores

    # Top-5 answers plus the overall argmax answer.
    topkscores, topkinds = torch.topk(scores, 5, dim=1)
    answers = scores.argmax(dim=1)

    def decode_answer(aid, row):
        """Map one answer index to its string (shared by topk and argmax)."""
        if aid >= answer_space_size:
            # Copy-mechanism index: points into this question's context tokens.
            answer = report.context_tokens[row][aid - answer_space_size]
            if answer == self.context_processor.PAD_TOKEN:
                answer = "unanswerable"
        else:
            answer = self.answer_processor.idx2word(aid)
        # Undo tokenizer spacing around possessives.
        return answer.replace(" 's", "'s")

    predictions = []
    for idx, question_id in enumerate(report.id):
        assert len(topkscores[idx]) == len(topkinds[idx]) == 5
        topk_ans_scores = [
            (decode_answer(aid.item(), idx), score.item())
            for score, aid in zip(topkscores[idx], topkinds[idx])
        ]
        predictions.append(
            {
                "question_id": question_id.item(),
                "topk": topk_ans_scores,
                "answer": decode_answer(answers[idx].item(), idx),
            }
        )
    # NOTE: the old trailing `info` dict was dead code (built from a stale
    # loop variable and never used) and has been removed.
    return predictions