def get_dataset_info()

in projects/krisp/graphnetwork_module.py


    def get_dataset_info(self, config):
        # Load dataset info
        dataset_data = torch.load(mmf_indirect(config.dataset_info_path))
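        # Each entry is expected to look roughly like this
        # (inferred from the accesses below):
        #   {
        #       "id": <question id>,
        #       "symbols_q": [<question word>, ...],
        #       "in_names_confs": [(<name>, <conf>), ...],
        #       "places_names_confs": [(<name>, <conf>), ...],
        #       "lvis_names_confs": [(<name>, <conf>), ...],
        #       "vg_names_confs": [(<name>, <conf>), ...],
        #   }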

        # Collect symbol names and confidences from the pretrained classifiers
        # The four classifier sources are hardcoded below
        qid2qnode = {}
        qid2imginfo = {}
        for dat in dataset_data:
            # Get qid
            qid = dat["id"]

            # Get q symbols
            q_words = list(dat["symbols_q"])
            qid2qnode[qid] = q_words

            # Gather (name, conf) pairs from each classifier, tagging each
            # tuple with its source index: 0=in, 1=places, 2=lvis, 3=vg
            in_data = dat["in_names_confs"]
            in_data = [(name, conf, 0) for name, conf in in_data]
            places_data = dat["places_names_confs"]
            places_data = [(name, conf, 1) for name, conf in places_data]
            lvis_data = dat["lvis_names_confs"]
            lvis_data = [(name, conf, 2) for name, conf in lvis_data]
            vg_data = dat["vg_names_confs"]
            vg_data = [(name, conf, 3) for name, conf in vg_data]
            all_image_tuples = in_data + places_data + lvis_data + vg_data
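            # e.g., all_image_tuples = [("cat", 0.93, 0), ("kitchen", 0.81, 1), ...]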

            # Merge into a dict first (name -> Tensor of per-classifier confs)
            img_data = {}
            for name, conf, datasetind in all_image_tuples:
                if name in img_data:
                    # Name already seen: only overwrite this classifier's
                    # slot if the new confidence is higher
                    if conf > img_data[name][datasetind].item():
                        img_data[name][datasetind] = conf
                else:
                    # New name: start from zeros and set this classifier's slot
                    conf_data = torch.zeros(4)
                    conf_data[datasetind] = conf
                    img_data[name] = conf_data

            # Convert dict to tuples list and add to qid dict
            img_data = list(img_data.items())
            qid2imginfo[qid] = img_data
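
        # At this point: qid2qnode maps qid -> [word, ...] and
        # qid2imginfo maps qid -> [(name, Tensor of per-classifier confs), ...]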

        # Convert qid2qnode and qid2imginfo from qid -> (name, conf)
        # to qid -> (node_idx, conf), merging question and image info
        # Map each graph node name to a contiguous index
        name2node_idx = {name: idx for idx, name in enumerate(self.graph.nodes)}
        qid2nodeact = {}
        img_class_sz = None
        for qid in qid2qnode:
            # Get words / confs
            q_words = qid2qnode[qid]  # qid -> [qw_1, qw_2, ...]
            # qid -> [(iw_1, Tensor([conf_c1, conf_c2, ...])), ...]
            img_info = qid2imginfo[qid]
            img_words = [x[0] for x in img_info]
            img_confs = [x[1] for x in img_info]

            # Get the classifier-confidence size (the full node feature
            # adds one extra slot for the question-word flag)
            if img_class_sz is None:
                assert isinstance(img_confs[0], torch.Tensor)
                img_class_sz = img_confs[0].size(0)

            # Arrange the node info as
            # [q_flag, img_class_1_conf, img_class_2_conf, ...]
            # (w2v features are concatenated later in the forward pass)
            node_info = {}  # node_idx -> torch.Tensor(q, ic1, ic2, ...)
            for word in q_words:
                # Continue if q word is not in the graph
                if word not in name2node_idx:
                    continue

                # Add node info
                node_idx = name2node_idx[word]
                val = torch.zeros(img_class_sz + 1)
                val[0] = 1  # flag: this node is a question word
                node_info[node_idx] = val

            # Add img info to node info
            for word, img_confs_w in zip(img_words, img_confs):
                # Continue if img word not in graph
                if word not in name2node_idx:
                    continue

                node_idx = name2node_idx[word]
                if node_idx in node_info:
                    # Copy class confidences into the existing node vector
                    node_info[node_idx][1:].copy_(img_confs_w)
                else:
                    # New node (not a question word): q flag stays zero
                    val = torch.zeros(img_class_sz + 1)
                    val[1:].copy_(img_confs_w)
                    node_info[node_idx] = val

            # Add node info to dict
            # This structure will be used to dynamically create node info
            # during forward pass
            qid2nodeact[qid] = node_info
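            # e.g., node_info = {42: tensor([1.0, 0.93, 0.0, 0.0, 0.0])}
            # (node 42 is a question word with one classifier hit)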

        # Sanity check: report the average number of nodes activated per question
        num_acts_per_qid = np.mean([len(acts) for acts in qid2nodeact.values()])
        print("Average of %f nodes activated per question" % num_acts_per_qid)

        # Return the node index map, per-question activations, and conf size
        return name2node_idx, qid2nodeact, img_class_sz
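
The method assumes torch, numpy (as np), and mmf_indirect are imported at module level, as in the original file. For context, here is a minimal sketch of how the returned structures might be consumed to build dense node features; this is an illustration, not code from the module, and build_node_features is a hypothetical helper:

    import torch

    def build_node_features(node_info, num_nodes, img_class_sz):
        # One row per graph node:
        # [q_flag, conf_in, conf_places, conf_lvis, conf_vg]
        feats = torch.zeros(num_nodes, img_class_sz + 1)
        for node_idx, val in node_info.items():
            feats[node_idx] = val
        return feats

    # For a single question id qid:
    # feats = build_node_features(
    #     qid2nodeact[qid], len(name2node_idx), img_class_sz
    # )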