in source/sagemaker/sagemaker_graph_fraud_detection/dgl_fraud_detection/model/mxnet.py [0:0]
def __init__(self,
             g,
             in_feats,
             n_hidden,
             n_classes,
             n_layers,
             activation,
             dropout,
             aggregator_type):
    """Assemble a stack of SAGEConv layers topped by a dense output layer.

    Parameters
    ----------
    g
        Graph object; it is only stored on ``self.g`` here.
    in_feats
        Input width of the first SAGEConv layer.
    n_hidden
        Output width of every SAGEConv layer (and input width of all but
        the first).
    n_classes
        Number of units in the final ``gluon.nn.Dense`` layer.
    n_layers
        Number of SAGEConv layers when ``n_layers >= 1``; values below 1
        still yield exactly one SAGEConv layer.
    activation
        Passed unchanged as ``activation`` to every SAGEConv layer.
    dropout
        Passed unchanged as ``feat_drop`` to every SAGEConv layer.
    aggregator_type
        Passed unchanged as the aggregator type to every SAGEConv layer.
    """
    super(GraphSAGE, self).__init__()
    self.g = g
    with self.name_scope():
        self.layers = gluon.nn.Sequential()
        # Input width of each SAGEConv: the raw feature size for the first
        # layer, the hidden size for the remaining n_layers - 1 layers.
        # A negative repeat count produces an empty list, matching
        # range(n_layers - 1) being empty.
        conv_input_widths = [in_feats] + [n_hidden] * (n_layers - 1)
        # Build and add each layer in turn, in the same order as before, so
        # the auto-generated parameter prefixes inside name_scope stay the same.
        for width_in in conv_input_widths:
            self.layers.add(SAGEConv(width_in, n_hidden, aggregator_type,
                                     feat_drop=dropout,
                                     activation=activation))
        # The last entry is a plain dense layer producing n_classes outputs,
        # not a graph convolution.
        self.layers.add(gluon.nn.Dense(n_classes))