in source/sagemaker/sagemaker_graph_fraud_detection/dgl_fraud_detection/model/mxnet.py [0:0]
def __init__(self, g, in_dim, num_hidden, num_classes, num_layers, heads,
             activation, feat_drop, attn_drop, alpha, residual):
    """Build a stack of GATConv layers followed by a Dense output projection.

    Parameters
    ----------
    g :
        Stored as ``self.g``; presumably the DGL graph used in forward
        (TODO confirm against the forward method).
    in_dim : int
        Input feature size of the first GATConv layer.
    num_hidden : int
        Output feature size passed to every GATConv layer.
    num_classes : int
        Number of units of the final ``gluon.nn.Dense`` projection.
    num_layers : int
        Number of GATConv layers. One input layer is always built;
        ``num_layers - 1`` hidden layers are added after it.
    heads : sequence of int
        Head count per GATConv layer; indexed from ``heads[0]`` up to
        ``heads[num_layers - 1]``.
    activation :
        Stored as ``self.activation`` only; it is not handed to GATConv
        here (presumably applied in forward -- TODO confirm).
    feat_drop, attn_drop :
        Forwarded unchanged to every GATConv layer.
    alpha :
        Forwarded as GATConv's sixth positional argument (the LeakyReLU
        negative slope in DGL's GATConv -- verify for the DGL version used).
    residual : bool
        Residual flag for the hidden layers; the input layer never uses one.
    """
    super(GAT, self).__init__()
    self.g = g
    self.num_layers = num_layers
    self.activation = activation

    def build(width, n_heads, use_residual):
        # The same width is given for both members of the GATConv input pair.
        return GATConv((width, width), num_hidden, n_heads,
                       feat_drop, attn_drop, alpha, use_residual)

    # First layer consumes the raw input features and has no residual link.
    self.gat_layers = [build(in_dim, heads[0], False)]
    # With multi-head attention, each later layer's input width is
    # num_hidden times the previous layer's head count. The generator keeps
    # layers being constructed one at a time, in order.
    self.gat_layers.extend(
        build(num_hidden * heads[idx - 1], heads[idx], residual)
        for idx in range(1, num_layers)
    )

    # No in_units given, so Gluon defers the Dense input size until first use.
    self.output_proj = gluon.nn.Dense(num_classes)

    # The layers live in a plain Python list, so register each one explicitly.
    for idx, layer in enumerate(self.gat_layers):
        self.register_child(layer, "gat_layer_{}".format(idx))
    # NOTE(review): assigning a Block attribute (self.output_proj above) already
    # registers it with Gluon, so this adds a second child entry for the same
    # Dense. It is kept because removing it would change saved parameter names.
    self.register_child(self.output_proj, "dense_layer")