# projects/krisp/graphnetwork_module.py (excerpt)
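# NOTE: excerpt of a class method; it assumes torch, VocabDict, mmf_indirect,
# and self.name2node_idx are provided by the surrounding module.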
def get_answer_info(self, config):
# Get answer info:
# recreate the MMF answer vocab and align each answer with its graph node
answer_vocab = VocabDict(mmf_indirect(config.vocab_file))
assert len(answer_vocab) == config.num_labels
# If we're in okvqa v1.0, need to do this a bit differently
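# okvqa_v_mode options handled below:
#   "v1.0"        - map each v1.0 vocab answer to its most common v1.1 form
#                   (several vocab answers may share a graph node)
#   "v1.0-121"    - same, but the match must also be in the old graph vocab
#                   and the answer-to-node mapping is kept 1-to-1
#   "v1.0-121-mc" - match via the raw answer strings and resolve node
#                   conflicts by keeping the more frequent answer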
if config.okvqa_v_mode in ["v1.0", "v1.0-121", "v1.0-121-mc"]:
# Load the answer translation file (maps between the raw answer strings
# and the stemmed v1.0 vocab)
tx_data = torch.load(mmf_indirect(config.ans_translation_file))
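# tx_data holds the translation maps used below:
#   v10_2_v11_mc / v10_2_v11_count: v1.0 answer -> most-common / all v1.1 answers
#   v10_2_raw_mc / v10_2_raw_count: v1.0 answer -> most-common / all raw answers
#   v11_2_raw_count: answer frequencies used when resolving 1-to-1 conflicts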
if config.okvqa_v_mode in ["v1.0-121", "v1.0-121-mc"]:
old_graph_vocab = torch.load(mmf_indirect(config.old_graph_vocab_file))
# Build the list of answer node indices.
# These are needed if we want to index the answer nodes out of the graph,
# e.g. to do node classification on them.
index_in_ans = []
index_in_node = []
graph_answers = []
nomatch = []
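# Vocab answers that could not be matched to any graph node (count printed below)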
for ans_str in answer_vocab.word2idx_dict:
# Plain v1.0 mode: no 1-to-1 constraint between answers and nodes
if config.okvqa_v_mode == "v1.0":
# Convert it to the most common raw answer and
# see if it's in the graph
if ans_str not in tx_data["v10_2_v11_mc"]:
nomatch.append(ans_str)
continue
# Try most common
if tx_data["v10_2_v11_mc"][ans_str] in self.name2node_idx:
# Get raw answer string
raw_ans = tx_data["v10_2_v11_mc"][ans_str]
else:
# Otherwise try all other options
v11_counts = tx_data["v10_2_v11_count"][ans_str]
sorted_counts = sorted(
v11_counts.items(), key=lambda x: x[1], reverse=True
)
raw_ans = None
for k, _ in sorted_counts:
if k in self.name2node_idx:
raw_ans = k
break
# If still no match, continue
if raw_ans is None:
nomatch.append(ans_str)
continue
# Add ans_str to graph answers
graph_answers.append(ans_str)
# Get the node index
# Use the raw name since that's what matches to nodes
node_idx = self.name2node_idx[raw_ans]
index_in_node.append(node_idx)
# Get the vocab index
ans_idx = answer_vocab.word2idx(ans_str)
index_in_ans.append(ans_idx)
else:
# Convert it to the most common raw answer and see if
# it's in the graph
if ans_str not in tx_data["v10_2_v11_mc"]:
nomatch.append(ans_str)
continue
# Try raw too
if config.okvqa_v_mode == "v1.0-121-mc":
# Try most common
if tx_data["v10_2_raw_mc"][ans_str] in self.name2node_idx:
# Get raw answer string
raw_ans = tx_data["v10_2_raw_mc"][ans_str]
else:
# Otherwise try all other options
v11_counts = tx_data["v10_2_raw_count"][ans_str]
sorted_counts = sorted(
v11_counts.items(), key=lambda x: x[1], reverse=True
)
raw_ans = None
for k, _ in sorted_counts:
if k in self.name2node_idx:
raw_ans = k
break
# If still no match, continue
if raw_ans is None:
nomatch.append(ans_str)
continue
else:
# Try most common
if (
tx_data["v10_2_v11_mc"][ans_str] in self.name2node_idx
and tx_data["v10_2_v11_mc"][ans_str] in old_graph_vocab
):
# Get raw answer string
raw_ans = tx_data["v10_2_v11_mc"][ans_str]
else:
# Otherwise try all other options
v11_counts = tx_data["v10_2_v11_count"][ans_str]
sorted_counts = sorted(
v11_counts.items(), key=lambda x: x[1], reverse=True
)
raw_ans = None
for k, _ in sorted_counts:
if k in self.name2node_idx and k in old_graph_vocab:
raw_ans = k
break
# If still no match, continue
if raw_ans is None:
nomatch.append(ans_str)
continue
# Enforce a 1-to-1 answer-to-node mapping: has this node already been claimed?
if self.name2node_idx[raw_ans] in index_in_node:
if config.okvqa_v_mode == "v1.0-121-mc":
# Check which is more common
assert len(index_in_node) == len(graph_answers)
assert len(index_in_ans) == len(graph_answers)
idx = index_in_node.index(self.name2node_idx[raw_ans])
node_idx = index_in_node[idx]
old_ans_str = graph_answers[idx]
raw_counts = tx_data["v11_2_raw_count"][raw_ans]
assert ans_str in raw_counts and old_ans_str in raw_counts
assert ans_str != old_ans_str
# If the new answer is more common, go back and replace the earlier entry
if raw_counts[ans_str] > raw_counts[old_ans_str]:
assert node_idx == self.name2node_idx[raw_ans]
graph_answers[idx] = ans_str
ans_idx = answer_vocab.word2idx(ans_str)
index_in_ans[idx] = ans_idx
else:
continue
else:
nomatch.append(ans_str)
continue
else:
# Add ans_str to graph answers
graph_answers.append(ans_str)
# Get the node index
# Use the raw name since that's what matches to nodes
node_idx = self.name2node_idx[raw_ans]
index_in_node.append(node_idx)
# Get the vocab index
ans_idx = answer_vocab.word2idx(ans_str)
index_in_ans.append(ans_idx)
print("%d answers not matches" % len(nomatch))
# Get node indices for the alphabetized graph answers too
ans2node = dict(zip(graph_answers, index_in_node))
graph_answers = sorted(graph_answers)
graph_ans_node_idx = []
for ans_str in graph_answers:
# Look up the node index for this answer string
node_idx = ans2node[ans_str]
graph_ans_node_idx.append(node_idx)
else:
assert config.okvqa_v_mode == "v1.1"
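# In v1.1 mode the vocab entries are matched directly against node names,
# so no translation file is needed.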
# Build the list of answer node indices.
# These are needed if we want to index the answer nodes out of the graph,
# e.g. to do node classification on them.
index_in_ans = []
index_in_node = []
graph_answers = []
for ans_str in answer_vocab.word2idx_dict:
# Check if it's in the graph
if ans_str not in self.name2node_idx:
continue
# Add ans_str to graph answers
graph_answers.append(ans_str)
# Get the node index
node_idx = self.name2node_idx[ans_str]
index_in_node.append(node_idx)
# Get the vocab index
ans_idx = answer_vocab.word2idx(ans_str)
index_in_ans.append(ans_idx)
# Get node indices for the alphabetized graph answers too
graph_answers = sorted(graph_answers)
graph_ans_node_idx = []
for ans_str in graph_answers:
# Get node index
node_idx = self.name2node_idx[ans_str]
graph_ans_node_idx.append(node_idx)
# Sanity checks
# Should be same length
assert len(index_in_ans) == len(index_in_node)
# And no repeats
assert len(index_in_ans) == len(set(index_in_ans))
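# In plain v1.0 mode several vocab answers can map to the same node,
# so node indices are only required to be unique in the other modes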
if config.okvqa_v_mode != "v1.0":
assert len(index_in_node) == len(set(index_in_node))
assert len(graph_answers) == len(graph_ans_node_idx)
# Report how many answers overlap with the graph (sanity check)
num_ans_in_graph = len(index_in_ans)
print("%d answers in graph" % num_ans_in_graph)
# Convert to tensors now
index_in_ans = torch.LongTensor(index_in_ans)
index_in_node = torch.LongTensor(index_in_node)
graph_ans_node_idx = torch.LongTensor(graph_ans_node_idx)
return index_in_ans, index_in_node, graph_answers, graph_ans_node_idx
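# Illustrative usage sketch: assuming `node_logits` holds one score per graph
# node and `vocab_logits` one score per answer-vocab entry (both names are
# hypothetical), the returned indices could be used to mix graph scores into
# the answer-vocab scores:
#
#     index_in_ans, index_in_node, graph_answers, graph_ans_node_idx = (
#         self.get_answer_info(config)
#     )
#     graph_scores = node_logits[index_in_node]   # scores of the answer nodes
#     vocab_logits[index_in_ans] += graph_scores  # add into matching vocab slots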