# format_for_prediction — excerpt from mmf/datasets/builders/okvqa/dataset.py


    def format_for_prediction(self, report):
        """Convert a model ``report`` into a list of prediction dicts.

        For each sample, produces a dict with the question id, the top-1
        answer string, and the top-5 ``(answer, score)`` pairs. When the
        score matrix is wider than the true answer vocabulary, the extra
        columns are assumed to be graph-answer logits (handled by
        ``GraphVQAAnswerProcessor``) and are first collapsed into the
        regular vocabulary columns.

        Args:
            report: batch report exposing ``scores`` (batch x num_answers
                float tensor), ``id`` (per-sample question-id tensors) and
                ``context_tokens`` (per-sample token lists for copy-style
                answers).

        Returns:
            list[dict]: one dict per sample with keys ``question_id``,
            ``topk`` and ``answer``.
        """
        reg_vocab_sz = self.answer_processor.get_true_vocab_size()

        # Check for case of scores coming from graph: extra columns after
        # the regular vocab hold per-graph-answer confidences.
        if report.scores.size(1) > reg_vocab_sz:
            # Should actually have the graph_vqa_answer processor
            assert type(self.answer_processor.processor) is GraphVQAAnswerProcessor

            # Collapse into one set of confs (i.e. copy graph ones over if
            # conf is greater). Again, assumes graph ans is subset of all
            # answers. clone() keeps dtype/device, unlike the old
            # torch.Tensor(shape).copy_() which forced a float32 CPU copy.
            scores = report.scores.clone()
            for batch_ind in range(scores.size(0)):
                for graph_ind, graph_ans in enumerate(
                    self.answer_processor.graph_vocab
                ):
                    # Get graph conf
                    graph_conf = scores[batch_ind, reg_vocab_sz + graph_ind].item()

                    # Get non-graph conf for the same answer string
                    reg_idx = self.answer_processor.answer_vocab.word2idx(graph_ans)
                    assert (
                        reg_idx != self.answer_processor.answer_vocab.UNK_INDEX
                        and reg_idx < reg_vocab_sz
                    )
                    reg_conf = scores[batch_ind, reg_idx].item()

                    # Keep the max; set the graph column to -inf so it can
                    # never be selected by topk/argmax below.
                    scores[batch_ind, reg_idx] = max(graph_conf, reg_conf)
                    scores[batch_ind, reg_vocab_sz + graph_ind] = -float("Inf")
        else:
            scores = report.scores

        # Top-5 answers and the single best answer per sample.
        topkscores, topkinds = torch.topk(scores, 5, dim=1)
        answers = scores.argmax(dim=1)

        def _decode(sample_idx, ans_id):
            # Map an answer index to its string. Indices at or past the true
            # vocab size refer to this sample's context tokens (copy-style
            # answers); a padded context slot means nothing usable was picked.
            if ans_id >= reg_vocab_sz:
                answer = report.context_tokens[sample_idx][ans_id - reg_vocab_sz]
                if answer == self.context_processor.PAD_TOKEN:
                    answer = "unanswerable"
            else:
                answer = self.answer_processor.idx2word(ans_id)
            # Normalize tokenized possessives ("word 's" -> "word's").
            return answer.replace(" 's", "'s")

        predictions = []
        for idx, question_id in enumerate(report.id):
            assert len(topkscores[idx]) == len(topkinds[idx]) == 5

            # Top-k (answer, score) pairs for this sample.
            topk_ans_scores = [
                (_decode(idx, aid.item()), score.item())
                for score, aid in zip(topkscores[idx], topkinds[idx])
            ]

            predictions.append(
                {
                    "question_id": question_id.item(),
                    "topk": topk_ans_scores,
                    "answer": _decode(idx, answers[idx].item()),
                }
            )

        return predictions