def __init__()

in projects/krisp/graphnetwork_module.py [0:0]
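
Constructor for the KRISP graph-network module: it loads and prunes the knowledge graph, builds the per-node input features (question, classifier, and optional w2v segments), wires up the dataset- and answer-specific index mappings, and instantiates the underlying GraphNetwork.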


    def __init__(self, config, config_extra=None):
        super().__init__()
        self.config = config
        if config_extra is None:
            self.config_extra = {}
        else:
            self.config_extra = config_extra

        # Load the input graph
        raw_graph = torch.load(mmf_indirect(config.kg_path))
        self.graph, self.graph_idx, self.edge_index, self.edge_type = make_graph(
            raw_graph, config.prune_culdesacs
        )
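        # edge_index has one column per edge and edge_type one relation id per
        # edge (checked by the shape asserts below)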

        # Get all the useful graph attributes
        self.num_nodes = len(self.graph.nodes)
        assert len(self.graph_idx.nodes) == self.num_nodes
        self.num_edges = len(self.graph.edges)
        assert len(self.graph_idx.edges) == self.num_edges
        assert self.edge_index.shape[1] == self.num_edges
        assert self.edge_type.shape[0] == self.num_edges
        self.num_relations = len(raw_graph["relations2idx"])

        # Get the dataset specific info and relate it to the constructed graph
        (
            self.name2node_idx,
            self.qid2nodeact,
            self.img_class_sz,
        ) = self.get_dataset_info(config)
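        # name2node_idx maps a node name to its index in the constructed graph
        # (used below to copy w2v features); img_class_sz is the width of the
        # per-node segment reserved for the "classifiers" input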

        # And get the answer related info
        (
            self.index_in_ans,
            self.index_in_node,
            self.graph_answers,
            self.graph_ans_node_idx,
        ) = self.get_answer_info(config)

        # Save graph answers (to be used by data loader)
        torch.save(self.graph_answers, mmf_indirect(config.graph_vocab_file))

        # If features have w2v, initialize it here
        node2vec_filename = mmf_indirect(config.node2vec_filename)
        node_names = list(self.name2node_idx.keys())
        valid_node2vec = False
        if os.path.exists(node2vec_filename):
            with open(node2vec_filename, "rb") as f:
                node2vec, node_names_saved, no_match_nodes = pickle.load(f)

            # Make sure the nodes here are identical (otherwise,
            # when we update graph code, we might have the wrong graph)
            if set(node_names) == set(node_names_saved):
                valid_node2vec = True

        # Generate node2vec if not done already
        if not valid_node2vec:
            node2vec, node_names_dbg, no_match_nodes = prepare_embeddings(
                node_names,
                mmf_indirect(config.embedding_file),
                config.add_w2v_multiword,
            )
            print("Saving synonym2vec to pickle file:", node2vec_filename)
            pickle.dump(
                (node2vec, node_names_dbg, no_match_nodes),
                open(node2vec_filename, "wb"),
            )

        # Get the embedding size from any node2vec entry
        self.w2v_sz = next(iter(node2vec.values())).shape[0]

        # Get node input dim
        self.in_node_dim = 0
        self.q_offset = 0
        self.img_offset = 0
        self.vb_offset = 0
        self.q_enc_offset = 0
        self.w2v_offset = 0
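
        # Each *_offset marks where that input's slice starts within a node's
        # feature vector; in_node_dim grows as the slices below are appended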

        # Add question (size 1)
        if "question" in config.node_inputs:
            self.q_offset = self.in_node_dim
            self.in_node_dim += 1

        # Add classifiers
        if "classifiers" in config.node_inputs:
            self.img_offset = self.in_node_dim
            self.in_node_dim += self.img_class_sz

        # Add w2v
        if "w2v" in config.node_inputs:
            self.w2v_offset = self.in_node_dim
            self.in_node_dim += self.w2v_sz

        # Doing no w2v as a separate option to make this code a lot simpler
        self.use_w2v = config.use_w2v
        if self.use_w2v:
            # Create the base node feature matrix
            # torch.Tensor of size num_nodes x in_node_dim
            # In forward pass, will need to copy this batch_size times and
            # convert to cuda
            self.base_node_features = torch.zeros(self.num_nodes, self.in_node_dim)

            # Copy over w2v
            for node_name in node2vec:
                # Get w2v, convert to torch, then copy over
                w2v = torch.from_numpy(node2vec[node_name])
                node_idx = self.name2node_idx[node_name]
                self.base_node_features[
                    node_idx, self.w2v_offset : self.w2v_offset + self.w2v_sz
                ].copy_(w2v)
        else:
            self.in_node_dim -= self.w2v_sz
            self.base_node_features = torch.zeros(self.num_nodes, self.in_node_dim)

        # Init
        full_node_dim = self.in_node_dim
        special_input_node = False
        special_input_sz = None
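
        # Two input modes: with feed_special_node, cross-modal inputs are fed
        # through a single special input node of size special_input_sz;
        # otherwise they are appended to every node's input features (else
        # branch below)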

        # If feed_special_node, set inputs to graph network
        if (
            "feed_special_node" in self.config_extra
            and self.config_extra["feed_special_node"]
        ):
            assert not self.config_extra["compress_crossmodel"]
            special_input_node = True
            special_input_sz = 0

            # Get input size
            if (
                "feed_vb_to_graph" in self.config_extra
                and self.config_extra["feed_vb_to_graph"]
                and self.config_extra["feed_mode"] == "feed_vb_logit_to_graph"
            ):
                special_input_sz += self.config.num_labels
            if (
                "feed_vb_to_graph" in self.config_extra
                and self.config_extra["feed_vb_to_graph"]
                and self.config_extra["feed_mode"] == "feed_vb_hid_to_graph"
            ):
                special_input_sz += self.config_extra["vb_hid_sz"]
            if (
                "feed_q_to_graph" in self.config_extra
                and self.config_extra["feed_q_to_graph"]
            ):
                special_input_sz += self.config_extra["q_hid_sz"]

        # Otherwise, we feed into every graph node at start
        else:
            # Add vb conf (just the conf)
            if (
                "feed_vb_to_graph" in self.config_extra
                and self.config_extra["feed_vb_to_graph"]
                and self.config_extra["feed_mode"] == "feed_vb_logit_to_graph"
            ):
                assert not self.config_extra["compress_crossmodel"]
                self.vb_offset = self.in_node_dim
                full_node_dim += 1

            # Add vb vector
            if (
                "feed_vb_to_graph" in self.config_extra
                and self.config_extra["feed_vb_to_graph"]
                and self.config_extra["feed_mode"] == "feed_vb_hid_to_graph"
            ):
                self.vb_offset = self.in_node_dim
                if self.config_extra["compress_crossmodel"]:
                    full_node_dim += self.config_extra["crossmodel_compress_dim"]

                    # Make a compress layer (just a linear transform)
                    self.compress_linear = nn.Linear(
                        self.config_extra["vb_hid_sz"],
                        self.config_extra["crossmodel_compress_dim"],
                    )

                else:
                    full_node_dim += self.config_extra["vb_hid_sz"]

            # Add q vector
            if (
                "feed_q_to_graph" in self.config_extra
                and self.config_extra["feed_q_to_graph"]
            ):
                assert not self.config_extra["compress_crossmodel"]
                self.q_enc_offset = self.in_node_dim
                full_node_dim += self.config_extra["q_hid_sz"]

        # Set noback_vb
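        # Read without a default, so config_extra must always provide "noback_vb"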
        self.noback_vb = self.config_extra["noback_vb"]

        # Convert edge_index and edge_type matrices to torch
        # In forward pass, we repeat this by bs and convert to cuda
        self.edge_index = torch.from_numpy(self.edge_index)
        self.edge_type = torch.from_numpy(self.edge_type)

        # These are the forward pass data inputs to graph network
        # They are None to start until we know the batch size
        self.node_features_forward = None
        self.edge_index_forward = None
        self.edge_type_forward = None

        # Make graph network itself
        self.gn = GraphNetwork(
            config,
            full_node_dim,
            self.num_relations,
            self.num_nodes,
            special_input_node=special_input_node,
            special_input_sz=special_input_sz,
        )

        # Init hidden debug (used for analysis)
        self.graph_hidden_debug = None
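
For reference, the sketch below lists the config fields this constructor reads, with placeholder values. The OmegaConf wrapper and the GraphNetworkModule class name are assumptions made for the example; neither appears in the excerpt above.

    # Illustrative only: field names come from the code above; the values, the
    # OmegaConf usage, and the class name are assumptions.
    from omegaconf import OmegaConf

    config = OmegaConf.create(
        {
            "kg_path": "path/to/raw_graph.pth",            # loaded with torch.load
            "prune_culdesacs": False,                       # passed to make_graph
            "node_inputs": ["question", "classifiers", "w2v"],
            "use_w2v": True,
            "node2vec_filename": "path/to/node2vec.pkl",    # embedding cache
            "embedding_file": "path/to/word_embeddings",    # used by prepare_embeddings
            "add_w2v_multiword": True,
            "graph_vocab_file": "path/to/graph_answers.pth",
            "num_labels": 3000,                             # read only when feeding vb logits
            # ... plus whatever GraphNetwork itself consumes from config
        }
    )

    # config_extra must at least define "noback_vb" (read without a default above)
    module = GraphNetworkModule(config, config_extra={"noback_vb": False})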