in mmf/projects/krisp/graphnetwork_module.py [0:0]
def forward(self, sample_list):
    """Run the graph network on a batch and return its (re-indexed) output.

    The batch is laid out as ``batch_size`` independent copies of the base
    graph stacked into one large graph: node features are tiled row-wise and
    every copy of ``edge_index`` is shifted by ``num_nodes * batch_ind``.
    These buffers are cached on ``self`` and only rebuilt when the batch size
    changes.

    Args:
        sample_list: Mapping-like batch. ``sample_list["id"]`` must be a
            tensor of question ids (its first dimension is the batch size and
            its device is used for all allocations). Depending on
            ``self.config_extra`` the keys ``"vb_logits"``, ``"vb_hidden"``
            and ``"q_encoded"`` are also read. If the graph network returns a
            special-node output it is written back under
            ``"graph_special_node_out"``.

    Returns:
        The graph network output, re-indexed according to
        ``self.config.output_type`` / ``self.config.output_order``.

    Raises:
        ValueError: If ``self.gn.gcn_type`` is not one of ``"RGCN"``,
            ``"GCN"`` or ``"SAGE"``, or if ``feed_special_node`` is enabled
            without any source for the special node input.
        AssertionError: If the configuration or output shapes violate the
            invariants checked below.
    """
    # Get the batch size, qids, and device
    qids = sample_list["id"]
    batch_size = qids.size(0)
    device = qids.device
    num_nodes = self.num_nodes
    extra = self.config_extra

    def _enabled(name):
        # A flag counts as enabled only if it is present and truthy.
        return name in extra and extra[name]

    def _maybe_detach(tensor):
        # Optionally block gradients from flowing back into the "vb" model.
        return tensor.detach() if self.noback_vb else tensor

    def _append_per_node(feats, per_node_extra):
        # Concatenate extra features (batch, num_nodes, k) onto the flat
        # (batch * num_nodes, d) node feature matrix.
        feats = feats.reshape((batch_size, num_nodes, -1))
        feats = torch.cat([feats, per_node_extra], dim=2)
        return feats.reshape((batch_size * num_nodes, -1))

    # First forward pass, or the batch size changed: (re)allocate everything
    if (
        self.node_features_forward is None
        or batch_size * num_nodes != self.node_features_forward.size(0)
    ):
        self.node_features_forward = torch.zeros(
            num_nodes * batch_size, self.in_node_dim, device=device
        )
        _, num_edges = self.edge_index.size()
        self.edge_index_forward = torch.zeros(
            2, num_edges * batch_size, dtype=torch.long, device=device
        )
        uses_edge_types = self.gn.gcn_type == "RGCN"
        if uses_edge_types:
            self.edge_type_forward = torch.zeros(
                num_edges * batch_size, dtype=torch.long, device=device
            )

        # Fill in one copy of the base graph per batch element
        for batch_ind in range(batch_size):
            node_rows = slice(num_nodes * batch_ind, num_nodes * (batch_ind + 1))
            edge_cols = slice(batch_ind * num_edges, (batch_ind + 1) * num_edges)

            # Base node features are copied without modification
            self.node_features_forward[node_rows, :].copy_(self.base_node_features)

            # Edge indices are shifted by num_nodes * batch_ind, which makes
            # the batch equivalent to batch_size independent subgraphs
            self.edge_index_forward[:, edge_cols].copy_(self.edge_index)
            self.edge_index_forward[:, edge_cols].add_(batch_ind * num_nodes)

            # Edge types are copied without modification
            if uses_edge_types:
                self.edge_type_forward[edge_cols].copy_(self.edge_type)

    # Zero fill the confidences for node features
    assert (
        self.w2v_offset is not None
        and self.q_offset is not None
        and self.img_offset is not None
    )
    assert self.w2v_offset > 0
    self.node_features_forward[:, : self.w2v_offset].zero_()

    # Node indices activated by the questions in this batch. Defined up front
    # so it also exists (empty) when confidences are disabled.
    all_node_idx = []

    # If not using confs, the confidence columns simply stay at zero
    if self.config.use_conf:
        # Decide which confidence slots to blank out. Slot 0 is the question
        # confidence and slots 1-4 are the image confidences.
        if not self.config.use_q:
            assert self.config.use_img
            zero_slots = (0,)
        elif not self.config.use_img:
            zero_slots = (1, 2, 3, 4)
        elif self.config.use_partial_img:
            # Keep exactly one image confidence, zero the other three
            assert self.config.partial_img_idx in [0, 1, 2, 3]
            zero_slots = tuple(
                slot
                for slot in (1, 2, 3, 4)
                if slot - 1 != self.config.partial_img_idx
            )
        else:
            zero_slots = ()

        conf_width = self.img_offset + self.img_class_sz

        # Fill in the new confidences for this batch based on qid
        for batch_ind, qid in enumerate(qids):
            node_info = self.qid2nodeact[qid.item()]
            for node_idx in node_info:
                node_val = node_info[node_idx]
                if zero_slots:
                    # Work on a copy so the cached activations in
                    # self.qid2nodeact are never modified by a forward pass.
                    node_val = node_val.clone()
                    for slot in zero_slots:
                        node_val[slot] = 0
                self.node_features_forward[
                    num_nodes * batch_ind + node_idx, :conf_width
                ].copy_(node_val)
                all_node_idx.append(node_idx)

    # If necessary, pass in "output nodes" depending on output calculation
    # This for instance tells the gn which nodes to subsample
    if self.gn.output_type == "graph_level_ansonly":
        # Node indices that are answers
        output_nodes = self.index_in_node
    elif self.gn.output_type == "graph_level_inputonly":
        # All activated nodes for the questions in this batch.
        # NOTE(review): indices are per-graph (no batch offset) and the tensor
        # stays on CPU, exactly as before -- confirm this is what gn expects.
        output_nodes = torch.tensor(all_node_idx, dtype=torch.long)
    else:
        output_nodes = None

    gn_kwargs = {"batch_size": batch_size, "output_nodes": output_nodes}

    if _enabled("feed_special_node"):
        # Extra inputs go to the special node; node features are fed as-is.
        node_feats = self.node_features_forward
        have_special_input = False
        special_node_input = None

        if _enabled("feed_vb_to_graph"):
            feed_mode = extra["feed_mode"]
            # Add vb conf (just the conf)
            if feed_mode == "feed_vb_logit_to_graph":
                special_node_input = torch.sigmoid(
                    _maybe_detach(sample_list["vb_logits"])
                )
                have_special_input = True
            # Add vb feats
            if feed_mode == "feed_vb_hid_to_graph":
                special_node_input = _maybe_detach(sample_list["vb_hidden"])
                have_special_input = True

        # Add q enc feats (takes precedence over the vb inputs above)
        if _enabled("feed_q_to_graph"):
            special_node_input = sample_list["q_encoded"]
            have_special_input = True

        if not have_special_input:
            raise ValueError(
                "feed_special_node is enabled but no special node input is "
                "configured (expected feed_vb_to_graph with a supported "
                "feed_mode, or feed_q_to_graph)"
            )
        gn_kwargs["special_node_input"] = special_node_input
    else:
        # Otherwise, concatenate the extra inputs onto every node's features
        node_feats = self.node_features_forward

        if _enabled("feed_vb_to_graph"):
            feed_mode = extra["feed_mode"]

            # Add vb conf (just the conf)
            if feed_mode == "feed_vb_logit_to_graph":
                assert not extra["compress_crossmodel"]
                vb_confs = torch.sigmoid(_maybe_detach(sample_list["vb_logits"]))
                # Scatter answer-vocab confidences onto their graph nodes;
                # nodes that are not answers keep a confidence of zero.
                vb_confs_graphindexed = torch.zeros(
                    batch_size, num_nodes, device=device
                )
                vb_confs_graphindexed[:, self.index_in_node] = vb_confs[
                    :, self.index_in_ans
                ]
                node_feats = _append_per_node(
                    node_feats, vb_confs_graphindexed.unsqueeze(2)
                )

            # Add vb feats
            if feed_mode == "feed_vb_hid_to_graph":
                vb_hid = _maybe_detach(sample_list["vb_hidden"])
                # Optionally compress vb_hidden
                if extra["compress_crossmodel"]:
                    vb_hid = F.relu(self.compress_linear(vb_hid))
                node_feats = _append_per_node(
                    node_feats, vb_hid.unsqueeze(1).repeat((1, num_nodes, 1))
                )

        # Add q enc feats
        if _enabled("feed_q_to_graph"):
            assert not extra["compress_crossmodel"]
            node_feats = _append_per_node(
                node_feats,
                sample_list["q_encoded"].unsqueeze(1).repeat((1, num_nodes, 1)),
            )

    # Do actual graph forward pass; only RGCN takes edge types
    gcn_type = self.gn.gcn_type
    if gcn_type == "RGCN":
        gn_args = (node_feats, self.edge_index_forward, self.edge_type_forward)
    elif gcn_type in ("GCN", "SAGE"):
        gn_args = (node_feats, self.edge_index_forward)
    else:
        raise ValueError(f"Unsupported gcn_type: {gcn_type!r}")
    output, spec_out = self.gn(*gn_args, **gn_kwargs)

    # Do any reindexing we need
    output_type = self.config.output_type
    if output_type == "hidden_ans":
        # Outputs graph hidden features, but re-indexes them to answer vocab
        # Same as graph_prediction, but before final prediction
        assert output.dim() == 3
        assert output.size(1) == num_nodes
        assert output.size(2) == self.config.node_hid_dim

        # If in graph_analysis mode, save the hidden states here
        if extra["analysis_mode"]:
            self.graph_hidden_debug = output

        # Reindex to match with self.graph_vocab
        if self.config.output_order == "alpha":
            output = output[:, self.graph_ans_node_idx, :]
            assert output.size(1) == len(self.graph_answers)
        else:
            assert self.config.output_order == "ans"
            # Re-index into answer_vocab; answers without a node stay zero
            outputs_tmp = torch.zeros(
                batch_size,
                self.config.num_labels,
                self.config.node_hid_dim,
                device=device,
            )
            outputs_tmp[:, self.index_in_ans, :] = output[:, self.index_in_node, :]
            output = outputs_tmp
    elif output_type in (
        "graph_level",
        "graph_level_ansonly",
        "graph_level_inputonly",
    ):
        # Do nothing here, fc will happen later
        pass
    else:
        assert output_type == "graph_prediction"
        # Output is size of graph
        assert output.dim() == 2
        assert output.size(1) == num_nodes

        # Re-index
        if self.config.output_order == "alpha":
            output = output[:, self.graph_ans_node_idx]
            assert output.size(1) == len(self.graph_answers)
        else:
            assert self.config.output_order == "ans"
            # Re-index into answer_vocab; answers without a node get a large
            # negative logit
            logits = torch.full(
                (batch_size, self.config.num_labels), -1e3, device=device
            )
            logits[:, self.index_in_ans] = output[:, self.index_in_node]
            output = logits

    # If we generated a spec_out in graph network, put in sample
    # list for other modules to use
    if spec_out is not None:
        sample_list["graph_special_node_out"] = spec_out

    return output