in source/sagemaker/sagemaker_graph_fraud_detection/dgl_fraud_detection/model/pytorch.py [0:0]
def __init__(self, g, in_dim, num_hidden, num_classes, num_layers, heads,
             activation, feat_drop, attn_drop, alpha, residual):
    """Assemble the stack of ``GATConv`` layers that make up the network.

    The stack is one input layer, ``num_layers - 1`` hidden layers and one
    output layer, appended to ``self.gat_layers`` in that order.

    Args:
        g: Graph object; only stored on ``self.g`` here.
        in_dim: Input size handed to the first ``GATConv``.
        num_hidden: Per-head output size of the input and hidden layers.
        num_classes: Per-head output size of the output layer.
        num_layers: Stored on ``self.num_layers``; also bounds the
            hidden-layer loop.
        heads: Indexable collection of head counts, one entry per layer.
            NOTE(review): the output layer reads ``heads[-2]`` and
            ``heads[-1]``, which presumably expects
            ``len(heads) == num_layers + 1`` -- confirm against callers.
        activation: Only stored on ``self.activation`` here; it is not
            passed to any ``GATConv`` in this constructor.
        feat_drop: Forwarded positionally to every ``GATConv``
            (presumably the feature dropout rate -- verify against the
            ``GATConv`` signature in use).
        attn_drop: Forwarded positionally to every ``GATConv``
            (presumably the attention dropout rate -- verify).
        alpha: Forwarded positionally to every ``GATConv`` (presumably
            the LeakyReLU negative slope -- verify).
        residual: Residual flag forwarded to the hidden and output
            layers; the input layer always receives ``False``.
    """
    super(GAT, self).__init__()
    self.g = g
    self.num_layers = num_layers
    self.gat_layers = nn.ModuleList()
    self.activation = activation

    # The same three trailing settings are handed to every layer.
    shared_args = (feat_drop, attn_drop, alpha)

    # Input projection: never residual.
    self.gat_layers.append(
        GATConv(in_dim, num_hidden, heads[0], *shared_args, False))

    # Hidden layers: each consumes the concatenated heads of its
    # predecessor, hence the num_hidden * heads[previous] input size.
    for layer_idx in range(1, num_layers):
        hidden_in = num_hidden * heads[layer_idx - 1]
        self.gat_layers.append(
            GATConv(hidden_in, num_hidden, heads[layer_idx],
                    *shared_args, residual))

    # Output projection: sized from the second-to-last head count and
    # emitting num_classes features per head.
    output_in = num_hidden * heads[-2]
    self.gat_layers.append(
        GATConv(output_in, num_classes, heads[-1], *shared_args, residual))