experiments/codes/model/gat/sig_edge_gat.py
def forward(self, batch):
    data = batch.world_graphs
    param_name_to_idx = {k: v for v, k in enumerate(self.weight_names)}
    assert data.x.size(0) == data.edge_indicator.size(0)
    # extract node embeddings
    # data.edge_indicator contains 0 for every ordinary node and a distinct
    # value > 0 for each unique relation
    x = F.embedding(
        data.edge_indicator,
        get_param(self.weights, param_name_to_idx, "common_emb"),
    )
    # edge attribute is None because we are not learning edge types here
    edge_attr = None
    if data.edge_index.dim() != 2:
        raise ValueError(
            f"expected edge_index of shape [2, num_edges], got {tuple(data.edge_index.shape)}"
        )
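    # intermediate GAT layers: dropout -> edge-aware conv -> ELU; the final
    # layer is applied after the loop (with one more dropout, no nonlinearity)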
    for nr in range(self.config.model.signature_gat.num_layers - 1):
        param_name_dict = self.prepare_param_idx(nr)
        x = F.dropout(
            x, p=self.config.model.signature_gat.dropout, training=self.training
        )
        x = self.edgeConvs[nr](
            x, data.edge_index, edge_attr, self.weights, param_name_dict
        )
        x = F.elu(x)
    x = F.dropout(
        x, p=self.config.model.signature_gat.dropout, training=self.training
    )
    param_name_dict = self.prepare_param_idx(
        self.config.model.signature_gat.num_layers - 1
    )
    if self.config.model.signature_gat.num_layers > 0:
        x = self.edgeConvs[self.config.model.signature_gat.num_layers - 1](
            x, data.edge_index, edge_attr, self.weights, param_name_dict
        )
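    # x now has shape [total_num_nodes, dim]; rows where edge_indicator > 0
    # are the relation ("edge-type") nodes whose embeddings are pooled below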
    # restore x into B x num_node x dim
    chunks = torch.split(x, batch.num_edge_nodes, dim=0)
    batches = [p.unsqueeze(0) for p in chunks]
    # the world graph forms a single batch, so take the first (and only) chunk
    # (this reuses the name `batch`, which is no longer needed as an argument)
    batch = batches[0][0]
    # average over edge-type nodes
    num_class = self.config.model.num_classes
    edge_emb = torch.zeros((num_class, batch.size(-1)))
    edge_emb = edge_emb.to(self.config.general.device)
    for ei_t in data.edge_indicator.unique():
        ei = ei_t.item()
        if ei == 0:
            # node of type "node", skip
            continue
        # node of type "edge", take
        # we subtract 1 here to re-align the vectors (L399 of data.py)
        edge_emb[ei - 1] = batch[data.edge_indicator == ei].mean(dim=0)
    return edge_emb, batch
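
# A minimal sketch of the two helpers forward() relies on. `get_param` and
# `prepare_param_idx` are defined elsewhere in the repository, so the names,
# prefix scheme, and data layout below are assumptions for illustration, not
# the actual implementations.

def get_param_sketch(weights, param_name_to_idx, name):
    # look up a parameter tensor by name via the name -> flat-index mapping
    # built at the top of forward()
    return weights[param_name_to_idx[name]]


def prepare_param_idx_sketch(weight_names, layer_idx):
    # hypothetical: collect the flat indices of the parameters owned by GAT
    # layer `layer_idx`, keyed by the local name the layer expects
    prefix = f"gat_{layer_idx}_"
    return {
        name[len(prefix):]: idx
        for idx, name in enumerate(weight_names)
        if name.startswith(prefix)
    }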