def forward()

in mmf/projects/krisp/graphnetwork_module.py [0:0]


    def forward(self, sample_list):
        """Run the graph network on one batch and return its re-indexed output.

        The method keeps ``batch_size`` copies of the base graph in
        ``self.node_features_forward`` / ``self.edge_index_forward`` (and
        ``self.edge_type_forward`` for RGCN), refreshes the per-question
        confidence columns from ``self.qid2nodeact``, optionally appends or
        feeds extra inputs selected by ``self.config_extra``, calls
        ``self.gn`` and finally re-indexes the result according to
        ``self.config.output_type`` and ``self.config.output_order``.

        Args:
            sample_list: Mapping-like batch. ``"id"`` (tensor of question
                ids) is always read; ``"vb_logits"``, ``"vb_hidden"`` and
                ``"q_encoded"`` are read when the matching ``config_extra``
                switches are on. If ``self.gn`` returns a non-``None`` second
                value it is stored under ``"graph_special_node_out"``.

        Returns:
            The first value returned by ``self.gn``, re-indexed for the
            ``"hidden_ans"`` and ``"graph_prediction"`` output types and
            passed through unchanged for the ``graph_level*`` types.

        Raises:
            ValueError: If ``self.gn.gcn_type`` is not ``"RGCN"``, ``"GCN"``
                or ``"SAGE"``, or if ``feed_special_node`` is enabled but no
                source for the special node input is configured.
        """
        # Get the batch size, qids, and device
        qids = sample_list["id"]
        batch_size = qids.size(0)
        device = qids.device

        def enabled(key):
            # Optional switch in config_extra; a missing key counts as "off".
            return key in self.config_extra and self.config_extra[key]

        def fetch(key):
            # Read a tensor from the batch, detached when self.noback_vb is set.
            value = sample_list[key]
            return value.detach() if self.noback_vb else value

        def append_per_node(node_feats, extra):
            # Concatenate `extra` (batch, num_nodes, k) onto the flat
            # (batch * num_nodes, d) node feature matrix along the feature dim.
            stacked = node_feats.reshape((batch_size, self.num_nodes, -1))
            stacked = torch.cat([stacked, extra], dim=2)
            return stacked.reshape((batch_size * self.num_nodes, -1))

        # Fail early with a clear message instead of hitting an unbound
        # `output` further down when the graph type is not handled.
        gcn_type = self.gn.gcn_type
        if gcn_type != "RGCN" and gcn_type not in ["GCN", "SAGE"]:
            raise ValueError(f"Unsupported gcn_type: {gcn_type!r}")
        uses_edge_types = gcn_type == "RGCN"

        # First, if this is first forward pass or batch size changed,
        # we need to allocate everything
        if (
            self.node_features_forward is None
            or batch_size * self.num_nodes != self.node_features_forward.size(0)
        ):
            # Allocate the data
            self.node_features_forward = torch.zeros(
                self.num_nodes * batch_size, self.in_node_dim, device=device
            )
            _, num_edges = self.edge_index.size()
            self.edge_index_forward = torch.zeros(
                2, num_edges * batch_size, dtype=torch.long, device=device
            )
            if uses_edge_types:
                self.edge_type_forward = torch.zeros(
                    num_edges * batch_size, dtype=torch.long, device=device
                )

            # Get initial values for data
            for batch_ind in range(batch_size):
                node_start = self.num_nodes * batch_ind
                edge_start = num_edges * batch_ind

                # Copy base_node_features without modification
                self.node_features_forward[
                    node_start : node_start + self.num_nodes, :
                ].copy_(self.base_node_features)

                # Copy edge_index, shifted by node_start, so that the batch
                # becomes batch_size independent subgraphs
                edge_block = self.edge_index_forward[
                    :, edge_start : edge_start + num_edges
                ]
                edge_block.copy_(self.edge_index)
                edge_block.add_(node_start)

                # And copy edge_types without modification
                if uses_edge_types:
                    self.edge_type_forward[
                        edge_start : edge_start + num_edges
                    ].copy_(self.edge_type)

        # Zero fill the confidences for node features
        assert (
            self.w2v_offset is not None
            and self.q_offset is not None
            and self.img_offset is not None
        )
        assert self.w2v_offset > 0
        self.node_features_forward[:, : self.w2v_offset].zero_()

        # Node indices activated by the questions in this batch. Stays empty
        # when confidences are disabled, so the "graph_level_inputonly" output
        # mode no longer trips over an unbound name.
        all_node_idx = []

        # If in not using confs mode, just leave these values at zero
        if self.config.use_conf:
            # Decide which slots of every cached activation vector are blanked
            # before it is written into the node features.
            if not self.config.use_q:
                assert self.config.use_img
                zero_slots = [0]  # Zero-out q
            elif not self.config.use_img:
                zero_slots = [1, 2, 3, 4]  # Zero-out img
            elif self.config.use_partial_img:
                # Keep only the image slot selected by partial_img_idx
                assert self.config.partial_img_idx in [0, 1, 2, 3]
                zero_slots = [
                    img_idx + 1
                    for img_idx in range(4)
                    if img_idx != self.config.partial_img_idx
                ]
            else:
                zero_slots = []

            conf_width = self.img_offset + self.img_class_sz

            # Fill in the new confidences for this batch based on qid
            for batch_ind, qid in enumerate(qids):
                # Fill in the activated nodes into node_features
                # These always start at zero
                node_info = self.qid2nodeact[qid.item()]
                for node_idx, node_val in node_info.items():
                    if zero_slots:
                        # Work on a copy: the cached activations in
                        # self.qid2nodeact must not be altered by a forward pass.
                        node_val = node_val.clone()
                        for slot in zero_slots:
                            node_val[slot] = 0
                    self.node_features_forward[
                        self.num_nodes * batch_ind + node_idx, :conf_width
                    ].copy_(node_val)
                    all_node_idx.append(node_idx)

        # If necessary, pass in "output nodes" depending on output calculation
        # This for instance tells the gn which nodes to subsample
        if self.gn.output_type == "graph_level_ansonly":
            output_nodes = self.index_in_node  # These are node indices that are answers
        elif self.gn.output_type == "graph_level_inputonly":
            # These are all non-zero nodes for the question
            output_nodes = torch.LongTensor(all_node_idx)
        else:
            output_nodes = None

        # Resolve which extra inputs are requested. "feed_mode" is only
        # consulted when feed_vb_to_graph is on.
        feed_vb = enabled("feed_vb_to_graph")
        feed_mode = self.config_extra["feed_mode"] if feed_vb else None
        feed_vb_logit = feed_vb and feed_mode == "feed_vb_logit_to_graph"
        feed_vb_hid = feed_vb and feed_mode == "feed_vb_hid_to_graph"
        feed_q = enabled("feed_q_to_graph")

        gn_kwargs = {"batch_size": batch_size, "output_nodes": output_nodes}

        # If we're feeding in special node, need a different forward pass into self.gn
        if enabled("feed_special_node"):
            node_feats = self.node_features_forward

            # Get special_node_input; later sources override earlier ones
            unset = object()
            special_node_input = unset
            if feed_vb_logit:
                # Add vb conf (just the conf)
                special_node_input = torch.sigmoid(fetch("vb_logits"))
            if feed_vb_hid:
                # Add vb feats
                special_node_input = fetch("vb_hidden")
            if feed_q:
                # Add q enc feats
                special_node_input = sample_list["q_encoded"]
            if special_node_input is unset:
                raise ValueError(
                    "feed_special_node is enabled but neither feed_vb_to_graph "
                    "(with a supported feed_mode) nor feed_q_to_graph is set"
                )
            gn_kwargs["special_node_input"] = special_node_input

        # Otherwise, proceed normally
        else:
            # Build node_forward and concat the other input types onto it
            node_feats = self.node_features_forward

            # Add vb conf (just the conf)
            if feed_vb_logit:
                assert not self.config_extra["compress_crossmodel"]
                # Go through answer vocab and copy conf into it
                vb_confs = torch.sigmoid(fetch("vb_logits"))
                vb_confs_graphindexed = torch.zeros(
                    batch_size, self.num_nodes, device=device
                )
                vb_confs_graphindexed[:, self.index_in_node] = vb_confs[
                    :, self.index_in_ans
                ]
                node_feats = append_per_node(
                    node_feats, vb_confs_graphindexed.unsqueeze(2)
                )

            # Add vb feats
            if feed_vb_hid:
                vb_hid = fetch("vb_hidden")
                # Optionally compress vb_hidden
                if self.config_extra["compress_crossmodel"]:
                    vb_hid = F.relu(self.compress_linear(vb_hid))
                node_feats = append_per_node(
                    node_feats, vb_hid.unsqueeze(1).repeat((1, self.num_nodes, 1))
                )

            # Add q enc feats
            if feed_q:
                assert not self.config_extra["compress_crossmodel"]
                node_feats = append_per_node(
                    node_feats,
                    sample_list["q_encoded"]
                    .unsqueeze(1)
                    .repeat((1, self.num_nodes, 1)),
                )

        # Do actual graph forward pass (RGCN additionally takes the edge types)
        if uses_edge_types:
            graph_args = (self.edge_index_forward, self.edge_type_forward)
        else:
            graph_args = (self.edge_index_forward,)
        output, spec_out = self.gn(node_feats, *graph_args, **gn_kwargs)

        # Do any reindexing we need
        output_type = self.config.output_type
        if output_type == "hidden_ans":
            # Outputs graph hidden features, but re-indexes them to answer vocab
            # Same as graph_prediction, but before final prediction
            assert output.size(1) == self.num_nodes
            assert output.size(2) == self.config.node_hid_dim
            assert output.dim() == 3

            # If in graph_analysis mode, save the hidden states here
            if self.config_extra["analysis_mode"]:
                self.graph_hidden_debug = output

            # Reindex to match with self.graph_vocab
            if self.config.output_order == "alpha":
                output = output[:, self.graph_ans_node_idx, :]
                assert output.size(1) == len(self.graph_answers)
            else:
                assert self.config.output_order == "ans"

                # Re-index into answer_vocab; labels without a node stay zero
                reindexed = torch.zeros(
                    batch_size,
                    self.config.num_labels,
                    self.config.node_hid_dim,
                    device=device,
                )
                reindexed[:, self.index_in_ans, :] = output[:, self.index_in_node, :]
                output = reindexed

        elif output_type in [
            "graph_level",
            "graph_level_ansonly",
            "graph_level_inputonly",
        ]:
            # Do nothing here, fc will happen later
            pass
        else:
            assert output_type == "graph_prediction"

            # Output is size of graph
            assert output.size(1) == self.num_nodes
            assert output.dim() == 2

            # Re-index
            if self.config.output_order == "alpha":
                output = output[:, self.graph_ans_node_idx]
                assert output.size(1) == len(self.graph_answers)
            else:
                assert self.config.output_order == "ans"

                # Re-index into answer_vocab; labels without a node get -1e3
                logits = torch.full(
                    (batch_size, self.config.num_labels), -1e3, device=device
                )
                logits[:, self.index_in_ans] = output[:, self.index_in_node]
                output = logits

        # If we generated a spec_out in graph network, put in sample
        # list for other modules to use
        if spec_out is not None:
            sample_list["graph_special_node_out"] = spec_out

        return output