in experiments/codes/model/gat/edge_gat.py [0:0]
def __init__(self, config, shared_embeddings=None):
    """Register per-layer EdgeGATConv parameters and the classifier weights.

    For each of ``config.model.gat.num_layers`` layers this registers, via
    ``self.add_weight``, the tensors ``weight``, ``att`` and ``edge_update``
    (plus ``bias`` when ``config.model.gat.bias`` is truthy), named
    ``"EdgeGATConv.<layer>.<param>"``. It then appends a matching
    ``EdgeGatConv`` to ``self.edgeConvs``. Finally it registers the
    classification weights through ``self.add_classify_weights``.

    Args:
        config: Experiment config; reads ``config.model.*`` (including
            ``config.model.gat.*``) and ``config.general.device``.
        shared_embeddings: Accepted for interface compatibility; not used
            by this constructor.
    """
    super(GatEncoder, self).__init__(config)
    # flag to enable one-hot embedding if needed
    self.graph_mode = True
    self.one_hot = self.config.model.gat.emb_type == "one-hot"
    # NOTE(review): a plain list does not register EdgeGatConv instances as
    # submodules (if EdgeGatConv is an nn.Module); the trainable tensors are
    # registered separately through add_weight below. Confirm this is intended.
    self.edgeConvs = []

    model_cfg = config.model
    gat_cfg = model_cfg.gat
    device = config.general.device

    # Every layer maps relation_embedding_dim -> relation_embedding_dim with
    # relation_embedding_dim edge features, so these are loop-invariant.
    in_channels = model_cfg.relation_embedding_dim
    out_channels = model_cfg.relation_embedding_dim
    edge_dim = model_cfg.relation_embedding_dim
    heads = gat_cfg.num_heads
    # With concat the bias covers heads * out_channels features; otherwise
    # it covers out_channels (same sizes as the two original bias branches).
    bias_dim = heads * out_channels if gat_cfg.concat else out_channels

    def _register(layer, param_name, shape, initializer):
        # Allocate a trainable tensor of `shape` on `device` and hand it to
        # add_weight together with its initializer.
        tensor = torch.Tensor(size=shape).to(device)
        tensor.requires_grad = True
        self.add_weight(
            tensor,
            f"EdgeGATConv.{layer}.{param_name}",
            initializer=initializer,
            weight_norm=model_cfg.weight_norm,
        )

    ## Add EdgeGATConv params
    for layer in range(gat_cfg.num_layers):
        _register(layer, "weight", (in_channels, heads * out_channels), glorot)
        _register(layer, "att", (1, heads, 2 * out_channels + edge_dim), glorot)
        _register(
            layer, "edge_update", (out_channels + edge_dim, out_channels), glorot
        )
        if gat_cfg.bias:
            _register(layer, "bias", (bias_dim,), (uniform, 1))
        self.edgeConvs.append(
            EdgeGatConv(
                in_channels,
                out_channels,
                edge_dim,
                heads=heads,
                concat=gat_cfg.concat,
                negative_slope=gat_cfg.negative_slope,
                dropout=gat_cfg.dropout,
                bias=gat_cfg.bias,
            )
        )

    ## Add classify params
    # NOTE(review): presumably two node embeddings plus one relation/edge
    # embedding are concatenated for classification; confirm against the
    # forward pass.
    in_class_dim = (
        model_cfg.relation_embedding_dim * 2 + model_cfg.relation_embedding_dim
    )
    self.add_classify_weights(in_class_dim)