in mmf/projects/krisp/graphnetwork_module.py [0:0]
def __init__(self, config, config_extra=None):
    """Build the knowledge graph, node features and graph network.

    Loads the graph from ``config.kg_path`` and derives dataset- and
    answer-specific index maps. Loads per-node word vectors from the cache
    file, or builds and caches them. Computes the per-node input layout and
    offsets, then constructs the underlying ``GraphNetwork``.

    Args:
        config: Module config. Attributes read directly here are
            ``kg_path``, ``prune_culdesacs``, ``graph_vocab_file``,
            ``node2vec_filename``, ``embedding_file``,
            ``add_w2v_multiword``, ``node_inputs``, ``use_w2v`` and
            ``num_labels``. The whole config is also passed to
            ``get_dataset_info``, ``get_answer_info`` and ``GraphNetwork``.
        config_extra: Optional mapping of cross-module feeding options:
            ``feed_special_node``, ``feed_vb_to_graph``, ``feed_mode``,
            ``feed_q_to_graph``, ``compress_crossmodel``, ``vb_hid_sz``,
            ``q_hid_sz``, ``crossmodel_compress_dim`` and ``noback_vb``.
            Missing boolean flags are treated as False. Defaults to an
            empty dict.

    Raises:
        ValueError: If no node embeddings are available.
    """
    super().__init__()
    self.config = config
    self.config_extra = {} if config_extra is None else config_extra
    extra = self.config_extra

    # Load the input graph
    raw_graph = torch.load(mmf_indirect(config.kg_path))
    self.graph, self.graph_idx, self.edge_index, self.edge_type = make_graph(
        raw_graph, config.prune_culdesacs
    )

    # Get all the useful graph attributes (and check they agree)
    self.num_nodes = len(self.graph.nodes)
    assert len(self.graph_idx.nodes) == self.num_nodes
    self.num_edges = len(self.graph.edges)
    assert len(self.graph_idx.edges) == self.num_edges
    assert self.edge_index.shape[1] == self.num_edges
    assert self.edge_type.shape[0] == self.num_edges
    self.num_relations = len(raw_graph["relations2idx"])

    # Get the dataset specific info and relate it to the constructed graph
    (
        self.name2node_idx,
        self.qid2nodeact,
        self.img_class_sz,
    ) = self.get_dataset_info(config)

    # And get the answer related info
    (
        self.index_in_ans,
        self.index_in_node,
        self.graph_answers,
        self.graph_ans_node_idx,
    ) = self.get_answer_info(config)

    # Save graph answers (to be used by data loader)
    torch.save(self.graph_answers, mmf_indirect(config.graph_vocab_file))

    # Reuse cached node word vectors only if they were built for exactly
    # the same set of node names; otherwise rebuild and refresh the cache.
    node2vec_filename = mmf_indirect(config.node2vec_filename)
    node_names = list(self.name2node_idx.keys())
    valid_node2vec = False
    if os.path.exists(node2vec_filename):
        # NOTE: pickle.load can execute arbitrary code; this cache file must
        # only ever be produced by this module, never taken from outside.
        with open(node2vec_filename, "rb") as f:
            node2vec, node_names_saved, no_match_nodes = pickle.load(f)
        # Make sure the nodes here are identical (otherwise, when we update
        # graph code, we might have the wrong graph)
        valid_node2vec = set(node_names) == set(node_names_saved)
    if not valid_node2vec:
        node2vec, node_names_dbg, no_match_nodes = prepare_embeddings(
            node_names,
            mmf_indirect(config.embedding_file),
            config.add_w2v_multiword,
        )
        print("Saving synonym2vec to pickle file:", node2vec_filename)
        # Context manager guarantees the cache is flushed and closed.
        with open(node2vec_filename, "wb") as f:
            pickle.dump((node2vec, node_names_dbg, no_match_nodes), f)

    # Embedding size, taken from any one node vector
    if not node2vec:
        raise ValueError(
            f"No node embeddings available (cache: {node2vec_filename})"
        )
    self.w2v_sz = next(iter(node2vec.values())).shape[0]

    # Per-node input layout: enabled sections are placed in the order
    # question (1) | classifiers (img_class_sz) | w2v (w2v_sz); each
    # *_offset is the start column of its section.
    self.in_node_dim = 0
    self.q_offset = 0
    # Legacy misspelled alias of q_offset, kept for backward compatibility
    self.q_offest = 0
    self.img_offset = 0
    self.vb_offset = 0
    self.q_enc_offset = 0
    self.w2v_offset = 0
    node_inputs = config.node_inputs
    # Add question (size 1)
    if "question" in node_inputs:
        self.q_offset = self.in_node_dim
        self.in_node_dim += 1
    # Add classifiers
    if "classifiers" in node_inputs:
        self.img_offset = self.in_node_dim
        self.in_node_dim += self.img_class_sz
    # Add w2v
    w2v_in_inputs = "w2v" in node_inputs
    if w2v_in_inputs:
        self.w2v_offset = self.in_node_dim
        self.in_node_dim += self.w2v_sz

    # Doing no w2v as a separate option to make this code a LOT simpler
    self.use_w2v = config.use_w2v
    if self.use_w2v:
        # Base node feature matrix of size num_nodes x in_node_dim.
        # In forward pass, will need to copy this batch_size times and
        # convert to cuda.
        self.base_node_features = torch.zeros(self.num_nodes, self.in_node_dim)
        w2v_end = self.w2v_offset + self.w2v_sz
        # Copy each node's word vector into its w2v columns
        for node_name, vec in node2vec.items():
            node_idx = self.name2node_idx[node_name]
            self.base_node_features[node_idx, self.w2v_offset : w2v_end].copy_(
                torch.from_numpy(vec)
            )
    else:
        # Only give back the w2v columns if they were actually reserved
        # above; subtracting unconditionally would under-size the features.
        if w2v_in_inputs:
            self.in_node_dim -= self.w2v_sz
        self.base_node_features = torch.zeros(self.num_nodes, self.in_node_dim)

    # Cross-modal feeding options (missing flags count as disabled)
    feed_vb = bool(extra.get("feed_vb_to_graph", False))
    feed_vb_logit = feed_vb and extra["feed_mode"] == "feed_vb_logit_to_graph"
    feed_vb_hid = feed_vb and extra["feed_mode"] == "feed_vb_hid_to_graph"
    feed_q = bool(extra.get("feed_q_to_graph", False))
    compress_crossmodel = extra.get("compress_crossmodel", False)

    full_node_dim = self.in_node_dim
    special_input_node = False
    special_input_sz = None
    if extra.get("feed_special_node", False):
        # Cross-modal inputs are given to the graph network as a separate
        # special input of size special_input_sz.
        assert not compress_crossmodel
        special_input_node = True
        special_input_sz = 0
        if feed_vb_logit:
            special_input_sz += self.config.num_labels
        if feed_vb_hid:
            special_input_sz += extra["vb_hid_sz"]
        if feed_q:
            special_input_sz += extra["q_hid_sz"]
    else:
        # Otherwise, we feed into every graph node at start by widening
        # the per-node feature dimension.
        # Add vb conf (just the conf)
        if feed_vb_logit:
            assert not compress_crossmodel
            self.vb_offset = self.in_node_dim
            full_node_dim += 1
        # Add vb vector (optionally compressed by a linear layer)
        if feed_vb_hid:
            self.vb_offset = self.in_node_dim
            if compress_crossmodel:
                full_node_dim += extra["crossmodel_compress_dim"]
                self.compress_linear = nn.Linear(
                    extra["vb_hid_sz"],
                    extra["crossmodel_compress_dim"],
                )
            else:
                full_node_dim += extra["vb_hid_sz"]
        # Add q vector
        if feed_q:
            assert not compress_crossmodel
            # NOTE(review): when both vb and q are fed, q_enc_offset equals
            # vb_offset; confirm the forward pass concatenates these features
            # rather than writing them by offset.
            self.q_enc_offset = self.in_node_dim
            full_node_dim += extra["q_hid_sz"]

    # noback_vb flag, consumed elsewhere (presumably stops gradients into
    # the vb features -- confirm in forward); absent means False.
    self.noback_vb = extra.get("noback_vb", False)

    # Convert edge_index and edge_type matrices to torch.
    # In forward pass, we repeat this by bs and convert to cuda.
    self.edge_index = torch.from_numpy(self.edge_index)
    self.edge_type = torch.from_numpy(self.edge_type)

    # These are the forward pass data inputs to graph network.
    # They are None to start until we know the batch size.
    self.node_features_forward = None
    self.edge_index_forward = None
    self.edge_type_forward = None

    # Make graph network itself
    self.gn = GraphNetwork(
        config,
        full_node_dim,
        self.num_relations,
        self.num_nodes,
        special_input_node=special_input_node,
        special_input_sz=special_input_sz,
    )

    # Init hidden debug (used for analysis)
    self.graph_hidden_debug = None