def get_answer_info()

in projects/krisp/graphnetwork_module.py [0:0]


    def get_answer_info(self, config):
        """Map answer-vocabulary entries to nodes of the knowledge graph.

        Loads the answer vocabulary from ``config.vocab_file`` (essentially
        recreating MMF's ``answer_vocab``). For every answer that can be
        matched to a graph node via ``self.name2node_idx``, it records the
        pair (answer-vocab index, graph node index).

        Matching depends on ``config.okvqa_v_mode``:

        * ``"v1.1"``: an answer matches if its string is itself a node name.
        * ``"v1.0"``: answers are converted to raw strings through the
          translation file ``config.ans_translation_file``. The most common
          translation is tried first, then the remaining ones in decreasing
          count order. Several answers may map to the same node.
        * ``"v1.0-121"`` / ``"v1.0-121-mc"``: same idea, but each node is
          used by at most one answer.
          - ``"v1.0-121"`` also requires the raw string to be in the vocab
            loaded from ``config.old_graph_vocab_file``. A later answer that
            hits an already-used node is dropped.
          - ``"v1.0-121-mc"`` uses the raw-answer translation tables. When
            two answers hit the same node, it keeps the one with the larger
            count in ``tx_data["v11_2_raw_count"][raw_ans]``.

        Args:
            config: Config object providing ``vocab_file``, ``num_labels``,
                ``okvqa_v_mode`` and, for the v1.0 modes,
                ``ans_translation_file`` (plus ``old_graph_vocab_file`` for
                the ``121`` modes).

        Returns:
            Tuple ``(index_in_ans, index_in_node, graph_answers,
            graph_ans_node_idx)``:

            * ``index_in_ans`` (LongTensor): answer-vocab indices of the
              matched answers.
            * ``index_in_node`` (LongTensor): graph node index paired with
              each entry of ``index_in_ans`` (same order).
            * ``graph_answers`` (list of str): matched answer strings, sorted
              alphabetically.
            * ``graph_ans_node_idx`` (LongTensor): node index for each entry
              of ``graph_answers`` (same alphabetical order).

        Raises:
            AssertionError: if the vocab size disagrees with
                ``config.num_labels``, the mode is not one of the four above,
                or a sanity check on the resulting indices fails.
        """
        # Recreates mmf answer_vocab here essentially
        answer_vocab = VocabDict(mmf_indirect(config.vocab_file))
        assert len(answer_vocab) == config.num_labels

        # Parallel lists: entry i pairs vocab index index_in_ans[i] with graph
        # node index_in_node[i]; graph_answers[i] is the answer string.
        # Important if we want to index those out to (for instance) do node
        # classification on them.
        index_in_ans = []
        index_in_node = []
        graph_answers = []

        # If we're in okvqa v1.0, need to do this a bit differently
        if config.okvqa_v_mode in ["v1.0", "v1.0-121", "v1.0-121-mc"]:
            # Load the answer translation file (to go from raw strings to
            # stemmed in v1.0 vocab)
            tx_data = torch.load(mmf_indirect(config.ans_translation_file))
            old_graph_vocab = None
            if config.okvqa_v_mode in ["v1.0-121", "v1.0-121-mc"]:
                old_graph_vocab = torch.load(
                    mmf_indirect(config.old_graph_vocab_file)
                )

            def in_graph(name):
                # Candidate is usable if it names a graph node
                return name in self.name2node_idx

            def in_graph_and_old_vocab(name):
                # v1.0-121 additionally restricts to the old graph vocab
                return name in self.name2node_idx and name in old_graph_vocab

            def pick_raw_answer(ans_str, mc_key, count_key, is_valid):
                # Return the most common translation of ans_str if valid,
                # otherwise the highest-count valid alternative, else None.
                # The count table is only read when the most common
                # translation is rejected.
                most_common = tx_data[mc_key][ans_str]
                if is_valid(most_common):
                    return most_common
                counts = tx_data[count_key][ans_str]
                for candidate, _ in sorted(
                    counts.items(), key=lambda x: x[1], reverse=True
                ):
                    if is_valid(candidate):
                        return candidate
                return None

            # Select which translation tables / validity test this mode uses
            if config.okvqa_v_mode == "v1.0-121-mc":
                mc_key, count_key = "v10_2_raw_mc", "v10_2_raw_count"
                is_valid = in_graph
            elif config.okvqa_v_mode == "v1.0-121":
                mc_key, count_key = "v10_2_v11_mc", "v10_2_v11_count"
                is_valid = in_graph_and_old_vocab
            else:
                mc_key, count_key = "v10_2_v11_mc", "v10_2_v11_count"
                is_valid = in_graph

            one_to_one = config.okvqa_v_mode != "v1.0"
            # Node index -> position in the parallel lists (first occurrence);
            # replaces linear searches of index_in_node for the 1-1 check.
            node2pos = {}
            nomatch = []
            for ans_str in answer_vocab.word2idx_dict:
                # NOTE(review): every mode gates on v10_2_v11_mc membership,
                # including v1.0-121-mc which then reads v10_2_raw_mc;
                # presumably both tables share keys -- confirm.
                if ans_str not in tx_data["v10_2_v11_mc"]:
                    nomatch.append(ans_str)
                    continue

                raw_ans = pick_raw_answer(ans_str, mc_key, count_key, is_valid)
                if raw_ans is None:
                    nomatch.append(ans_str)
                    continue

                # Use the raw name since that's what matches to nodes
                node_idx = self.name2node_idx[raw_ans]

                # Check 1 to 1
                if one_to_one and node_idx in node2pos:
                    if config.okvqa_v_mode != "v1.0-121-mc":
                        # Node already claimed: drop this answer
                        nomatch.append(ans_str)
                        continue

                    # Node already claimed: keep whichever answer is more
                    # common for this raw string
                    assert len(index_in_node) == len(graph_answers)
                    assert len(index_in_ans) == len(graph_answers)
                    idx = node2pos[node_idx]
                    assert index_in_node[idx] == node_idx
                    old_ans_str = graph_answers[idx]
                    raw_counts = tx_data["v11_2_raw_count"][raw_ans]
                    assert ans_str in raw_counts and old_ans_str in raw_counts
                    assert ans_str != old_ans_str

                    # If new answer more common, replace it in place (the
                    # node index at idx stays the same)
                    if raw_counts[ans_str] > raw_counts[old_ans_str]:
                        graph_answers[idx] = ans_str
                        index_in_ans[idx] = answer_vocab.word2idx(ans_str)
                    continue

                node2pos.setdefault(node_idx, len(index_in_node))
                graph_answers.append(ans_str)
                index_in_node.append(node_idx)
                index_in_ans.append(answer_vocab.word2idx(ans_str))
            print("%d answers not matches" % len(nomatch))
        else:
            assert config.okvqa_v_mode == "v1.1"

            for ans_str in answer_vocab.word2idx_dict:
                # Check if it's in the graph
                if ans_str not in self.name2node_idx:
                    continue

                graph_answers.append(ans_str)
                index_in_node.append(self.name2node_idx[ans_str])
                index_in_ans.append(answer_vocab.word2idx(ans_str))

        # Get node indices for the alphabetized graph answers too. Pair each
        # answer with the node it was actually matched to: in the v1.0 modes
        # the (stemmed) answer string may not itself be a node name. Answer
        # strings are vocab keys, hence unique, so a dict keyed on them is
        # safe.
        ans2node = dict(zip(graph_answers, index_in_node))
        graph_answers = sorted(graph_answers)
        graph_ans_node_idx = [ans2node[ans_str] for ans_str in graph_answers]

        # Sanity checks
        # Should be same length
        assert len(index_in_ans) == len(index_in_node)
        # And no repeats
        assert len(index_in_ans) == len(set(index_in_ans))
        if config.okvqa_v_mode != "v1.0":
            assert len(index_in_node) == len(set(index_in_node))
        assert len(graph_answers) == len(graph_ans_node_idx)

        # Check that the overlap is reasonable
        num_ans_in_graph = len(index_in_ans)
        print("%d answers in graph" % num_ans_in_graph)

        # Convert to tensors now
        index_in_ans = torch.LongTensor(index_in_ans)
        index_in_node = torch.LongTensor(index_in_node)
        graph_ans_node_idx = torch.LongTensor(graph_ans_node_idx)

        return index_in_ans, index_in_node, graph_answers, graph_ans_node_idx