def forward()

in experiments/codes/model/gat/sig_edge_gat.py [0:0]


    def forward(self, batch):
        """Encode the world graph and pool one embedding per relation id.

        Nodes are embedded by looking up ``data.edge_indicator`` in the
        ``"common_emb"`` weight, passed through the configured stack of
        ``self.edgeConvs`` layers (dropout before every layer, ELU after all
        but the last), and then averaged per non-zero indicator value.

        Args:
            batch: Object exposing ``world_graphs`` (graph data with ``x``,
                ``edge_indicator`` and ``edge_index``) and ``num_edge_nodes``
                (the split sizes handed to ``torch.split``).

        Returns:
            A tuple ``(edge_emb, node_emb)``. ``edge_emb`` has shape
            ``(config.model.num_classes, dim)``; row ``ei - 1`` is the mean of
            the rows of ``node_emb`` whose indicator equals ``ei`` and rows
            for indicator values that never occur stay zero. ``node_emb`` is
            the first chunk of the convolved features.

        Raises:
            ValueError: If ``data.edge_index`` is not 2-dimensional.
        """
        data = batch.world_graphs
        assert data.x.size(0) == data.edge_indicator.size(0)
        # Fail fast on a malformed edge index instead of dropping into a
        # debugger, which would hang any non-interactive run.
        if data.edge_index.dim() != 2:
            raise ValueError(
                "edge_index must be 2-dimensional, got "
                f"{data.edge_index.dim()} dimension(s)"
            )

        gat_cfg = self.config.model.signature_gat
        num_layers = gat_cfg.num_layers
        dropout = gat_cfg.dropout

        param_name_to_idx = {k: v for v, k in enumerate(self.weight_names)}
        # extract node embeddings
        # data.edge_indicator contains 0's for all nodes and value > 0 for each unique relations
        x = F.embedding(
            data.edge_indicator,
            get_param(self.weights, param_name_to_idx, "common_emb"),
        )
        # edge attribute is None because we are not learning edge types here
        edge_attr = None

        # All layers except the last one are followed by an ELU.
        for nr in range(num_layers - 1):
            param_name_dict = self.prepare_param_idx(nr)
            x = F.dropout(x, p=dropout, training=self.training)
            x = self.edgeConvs[nr](
                x, data.edge_index, edge_attr, self.weights, param_name_dict
            )
            x = F.elu(x)

        x = F.dropout(x, p=dropout, training=self.training)
        param_name_dict = self.prepare_param_idx(num_layers - 1)
        if num_layers > 0:
            # Final layer: no activation afterwards.
            x = self.edgeConvs[num_layers - 1](
                x, data.edge_index, edge_attr, self.weights, param_name_dict
            )

        # we only have one batch for world graph, so keep the first chunk only
        node_emb = torch.split(x, batch.num_edge_nodes, dim=0)[0]

        # sum over edge type nodes
        num_class = self.config.model.num_classes
        edge_emb = torch.zeros(
            (num_class, node_emb.size(-1)), device=self.config.general.device
        )
        # A single tolist() avoids one host synchronisation per relation id.
        for ei in data.edge_indicator.unique().tolist():
            if ei == 0:
                # node of type "node", skip
                continue
            # node of type "edge", take
            # we subtract 1 here to re-align the vectors (L399 of data.py)
            edge_emb[ei - 1] = node_emb[data.edge_indicator == ei].mean(dim=0)
        return edge_emb, node_emb