in mmf/projects/krisp/graphnetwork_module.py [0:0]
def get_dataset_info(self, config):
    # Load dataset info
    dataset_data = torch.load(mmf_indirect(config.dataset_info_path))
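    # Each entry of dataset_data is assumed (from the keys read below) to look like:
    #   {
    #       "id": question id,
    #       "symbols_q": iterable of question symbol names,
    #       "in_names_confs": [(name, conf), ...],
    #       "places_names_confs": [(name, conf), ...],
    #       "lvis_names_confs": [(name, conf), ...],
    #       "vg_names_confs": [(name, conf), ...],
    #   }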
    # Go through and collect symbol names and confs from our pretrained classifiers
    # Hardcoded to the classifiers
    qid2qnode = {}
    qid2imginfo = {}
    for dat in dataset_data:
        # Get qid
        qid = dat["id"]
        # Get q symbols
        q_words = list(dat["symbols_q"])
        qid2qnode[qid] = q_words
        # Get confidences
        in_data = dat["in_names_confs"]
        in_data = [(name, conf, 0) for name, conf in in_data]
        places_data = dat["places_names_confs"]
        places_data = [(name, conf, 1) for name, conf in places_data]
        lvis_data = dat["lvis_names_confs"]
        lvis_data = [(name, conf, 2) for name, conf in lvis_data]
        vg_data = dat["vg_names_confs"]
        vg_data = [(name, conf, 3) for name, conf in vg_data]
        all_image_tuples = in_data + places_data + lvis_data + vg_data
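        # Each (name, conf) pair is tagged with the index of its source classifier
        # (0 = in, 1 = places, 2 = lvis, 3 = vg); that index selects the slot in
        # the length-4 confidence vector built below.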
        # Make into dict to start (name -> conf Tensor)
        img_data = {}
        for name, conf, datasetind in all_image_tuples:
            # Check if name has been put in yet
            if name in img_data:
                # If yes, insert new confidence in the right place
                # Don't overwrite in same ind unless conf is higher
                if conf > img_data[name][datasetind].item():
                    img_data[name][datasetind] = conf
            else:
                # Otherwise, all zeros and add conf to the right index
                conf_data = torch.zeros(4)
                conf_data[datasetind] = conf
                img_data[name] = conf_data
        # Convert dict to tuples list and add to qid dict
        img_data = [(name, img_data[name]) for name in img_data]
        qid2imginfo[qid] = img_data

    # Convert qid2qnode and qid2imginfo to go from qid -> (name, conf)
    # to qid -> (node_idx, conf) and merge q and img info (concat)
    name2node_idx = {}
    idx = 0
    for nodename in self.graph.nodes:
        name2node_idx[nodename] = idx
        idx += 1
    qid2nodeact = {}
    img_class_sz = None
    for qid in qid2qnode:
        # Get words / confs
        q_words = qid2qnode[qid]  # qid -> [qw_1, qw_2, ...]
        # qid -> [(iw_1, conf_c1, conf_c2, ...), ...]
        img_info = qid2imginfo[qid]
        img_words = [x[0] for x in img_info]
        img_confs = [x[1] for x in img_info]
        # Get the node feature size
        if img_class_sz is None:
            # img_class_confs = img_confs[0]
            assert type(img_confs[0]) is torch.Tensor
            img_class_sz = img_confs[0].size(0)
        # We will arrange the node info
        # [q, img_class_1_conf, img_class_2_conf ... w2v]
        # Add to list
        node_info = {}  # node_idx -> torch.Tensor(q, ic1, ic2, ...)
        for word in q_words:
            # Continue if q word is not in the graph
            if word not in name2node_idx:
                continue
            # Add node info
            node_idx = name2node_idx[word]
            val = torch.zeros(img_class_sz + 1)
            val[0] = 1
            node_info[node_idx] = val
        # Add img info to node info
        for word, img_confs_w in zip(img_words, img_confs):
            # Continue if img word not in graph
            if word not in name2node_idx:
                continue
            node_idx = name2node_idx[word]
            if node_idx in node_info:
                # Append class info to existing node info
                node_info[node_idx][1:].copy_(img_confs_w)
            else:
                # Just prepend a zero to the img info (not a question word)
                val = torch.zeros(img_class_sz + 1)
                val[1:].copy_(img_confs_w)
                node_info[node_idx] = val
        # Add node info to dict
        # This structure will be used to dynamically create node info
        # during forward pass
        qid2nodeact[qid] = node_info

    # Check the average # of node activations is reasonable
    num_acts_per_qid = np.mean(
        [len(qid2nodeact[qid].keys()) for qid in qid2nodeact]
    )
    print("Average of %f nodes activated per question" % num_acts_per_qid)
    # Return
    return name2node_idx, qid2nodeact, img_class_sz
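
A minimal sketch of how these return values might be consumed when the node info is dynamically created during the forward pass, as the comment above suggests. This is an illustration only, not code from the file; `build_node_features` and `num_nodes` are hypothetical names.

import torch

def build_node_features(qid, qid2nodeact, num_nodes, img_class_sz):
    # Dense (num_nodes, img_class_sz + 1) matrix; row i holds
    # [q_flag, classifier confidences ...] for graph node i and
    # stays all-zero for nodes not activated for this question.
    feats = torch.zeros(num_nodes, img_class_sz + 1)
    for node_idx, val in qid2nodeact[qid].items():
        feats[node_idx] = val
    return feats

# e.g. feats = build_node_features(qid, qid2nodeact, len(name2node_idx), img_class_sz)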