# experiments/codes/model/gat/sig_edge_gat.py
def __init__(self, config, shared_embeddings=None):
    super(GatedNodeGatEncoder, self).__init__(config)
    self.graph_mode = True
    # flag to enable one-hot embeddings when configured
    self.one_hot = self.config.model.signature_gat.emb_type == "one-hot"
    # per-layer gated GAT convolutions, populated in the loop below
    self.edgeConvs = []
    # common node & relation embedding: one shared embedding row for all
    # entity nodes, plus an individual embedding row per relation type
    # (hence num_classes + 1 rows)
    emb = torch.Tensor(
        size=(config.model.num_classes + 1, config.model.relation_embedding_dim)
    ).to(config.general.device)
    # rel_emb = torch.Tensor(size=(1, config.model.relation_embedding_dim)).to(config.general.device)
    emb.requires_grad = True  # could instead be gated on config.model.signature_gat.learn_node_and_rel_emb
    torch.nn.init.xavier_normal_(emb)
    self.add_weight(emb, "common_emb", weight_norm=config.model.weight_norm)
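    # A minimal sketch of how this shared table is presumably consumed later in
    # the encoder (assumption: `get_weight` and `node_type_ids` are illustrative
    # names, not necessarily this repo's exact API):
    #   node_emb = self.get_weight("common_emb")  # (num_classes + 1, emb_dim)
    #   h = node_emb[node_type_ids]               # one embedding row per node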
    # Register per-layer GAT parameters
    for l in range(config.model.signature_gat.num_layers):
        in_channels = config.model.relation_embedding_dim
        out_channels = config.model.relation_embedding_dim
        heads = config.model.signature_gat.num_heads
        edge_dim = config.model.relation_embedding_dim
        # linear projection of node features into `heads` attention subspaces
        weight = torch.Tensor(size=(in_channels, heads * out_channels)).to(
            config.general.device
        )
        weight.requires_grad = True
        self.add_weight(
            weight,
            "GatedGATConv.{}.weight".format(l),
            initializer=glorot,
            weight_norm=config.model.weight_norm,
        )
        # per-head attention coefficients over concatenated source/target features
        att = torch.Tensor(size=(1, heads, 2 * out_channels)).to(
            config.general.device
        )
        att.requires_grad = True
        self.add_weight(
            att,
            "GatedGATConv.{}.att".format(l),
            initializer=glorot,
            weight_norm=config.model.weight_norm,
        )
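        # The shapes above follow torch_geometric's GATConv convention: `weight`
        # maps (N, in_channels) -> (N, heads * out_channels) and `att` scores the
        # concatenated source/target features per head. A sketch of the attention
        # logit (an assumption about GatedGatConv's internals, not verified here):
        #   alpha = leaky_relu(
        #       (torch.cat([x_i, x_j], dim=-1) * att).sum(dim=-1), negative_slope
        #   )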
        if l == 0:
            # the GRU gating weights are shared across layers, so register them
            # only once, under the "_all_" key
            gru_weight_ih = torch.Tensor(size=(out_channels, 3 * out_channels)).to(
                config.general.device
            )
            gru_weight_ih.requires_grad = True
            self.add_weight(
                gru_weight_ih,
                "GatedGATConv._all_.gru_w_ih",
                weight_norm=config.model.weight_norm,
            )
            gru_weight_hh = torch.Tensor(size=(out_channels, 3 * out_channels)).to(
                config.general.device
            )
            gru_weight_hh.requires_grad = True
            self.add_weight(
                gru_weight_hh,
                "GatedGATConv._all_.gru_w_hh",
                weight_norm=config.model.weight_norm,
            )
            gru_bias_ih = torch.Tensor(size=(3 * out_channels,)).to(
                config.general.device
            )
            gru_bias_ih.requires_grad = True
            self.add_weight(
                gru_bias_ih,
                "GatedGATConv._all_.gru_b_ih",
                initializer=(uniform, 1),
                weight_norm=config.model.weight_norm,
            )
            gru_bias_hh = torch.Tensor(size=(3 * out_channels,)).to(
                config.general.device
            )
            gru_bias_hh.requires_grad = True
            self.add_weight(
                gru_bias_hh,
                "GatedGATConv._all_.gru_b_hh",
                initializer=(uniform, 1),
                weight_norm=config.model.weight_norm,
            )
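        # The four gru_* tensors parameterize a GRUCell-style update shared by
        # every layer; note that (out_channels, 3 * out_channels) is the transpose
        # of nn.GRUCell's (3 * hidden_size, input_size) layout. A sketch of the
        # gated update (assumed, mirroring nn.GRUCell's equations):
        #   gi = aggregated @ gru_w_ih + gru_b_ih  # chunks into r/z/n gates
        #   gh = h_prev @ gru_w_hh + gru_b_hh
        #   r, z = sigmoid(...); n = tanh(...)     # as in nn.GRUCell
        #   h_new = (1 - z) * n + z * h_prev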
        if config.model.signature_gat.bias:
            # bias width depends on whether multi-head outputs are concatenated
            # (heads * out_channels) or averaged (out_channels)
            bias_dim = (
                heads * out_channels
                if config.model.signature_gat.concat
                else out_channels
            )
            bias = torch.Tensor(size=(bias_dim,)).to(config.general.device)
            bias.requires_grad = True
            self.add_weight(
                bias,
                "GatedGATConv.{}.bias".format(l),
                initializer=(uniform, 1),
                weight_norm=config.model.weight_norm,
            )
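        # The bias is presumably added after attention aggregation, matching
        # GATConv's semantics (an assumption about GatedGatConv's internals).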
        self.edgeConvs.append(
            GatedGatConv(
                in_channels,
                out_channels,
                edge_dim,
                heads=heads,
                concat=config.model.signature_gat.concat,
                negative_slope=config.model.signature_gat.negative_slope,
                dropout=config.model.signature_gat.dropout,
                bias=config.model.signature_gat.bias,
            )
        )
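    # End-to-end usage sketch (assumption: a forward pass elsewhere in this class
    # chains the convs and gates each layer's output through the shared GRU
    # weights; all names below are illustrative):
    #   h = common_emb_lookup(batch)               # rows of "common_emb"
    #   for conv in self.edgeConvs:
    #       h = conv(h, edge_index, edge_attr)     # attention step + GRU gate
    #   return h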