in source/sagemaker/sagemaker_graph_fraud_detection/dgl_fraud_detection/model/pytorch.py [0:0]
def __init__(self,
             g,
             in_feats,
             n_hidden,
             n_classes,
             n_layers,
             activation,
             dropout,
             aggregator_type):
    """Build a stack of SAGEConv layers followed by a linear output layer.

    The layers, in order, are:
      * ``SAGEConv(in_feats -> n_hidden)`` (input layer),
      * ``n_layers - 1`` x ``SAGEConv(n_hidden -> n_hidden)`` (hidden layers),
      * ``nn.Linear(n_hidden -> n_classes)`` (output layer).
    They are stored in ``self.layers`` (an ``nn.Sequential``) under the
    names ``"0"``, ``"1"``, ... in that order.

    Args:
        g: Graph stored on ``self.g``; presumably the DGLGraph the SAGEConv
            layers are applied to in ``forward`` (not visible here -- confirm).
        in_feats: Input feature size of the first SAGEConv layer.
        n_hidden: Output size of every SAGEConv layer and input size of the
            final linear layer.
        n_classes: Output size of the final linear layer.
        n_layers: Total number of SAGEConv layers. Values below 1 still
            produce the single input layer.
        activation: Activation passed to every SAGEConv layer.
        dropout: Feature dropout rate, passed to SAGEConv as ``feat_drop``.
        aggregator_type: Neighbour aggregator passed to every SAGEConv layer.
    """
    super(GraphSAGE, self).__init__()
    self.g = g

    # torch.nn.Module has no Gluon-style ``name_scope()``, and
    # ``nn.Sequential.add_module`` requires an explicit ``(name, module)``
    # pair. Passing the modules to the constructor registers them under
    # "0", "1", ... exactly as a Sequential would name them.
    modules = [
        # input layer
        SAGEConv(in_feats, n_hidden, aggregator_type,
                 feat_drop=dropout, activation=activation),
    ]
    # hidden layers
    modules.extend(
        SAGEConv(n_hidden, n_hidden, aggregator_type,
                 feat_drop=dropout, activation=activation)
        for _ in range(n_layers - 1)
    )
    # output layer
    modules.append(nn.Linear(n_hidden, n_classes))

    # NOTE(review): SAGEConv layers likely take the graph as an extra
    # argument, so ``forward`` presumably iterates/indexes ``self.layers``
    # rather than calling the Sequential directly -- confirm.
    self.layers = nn.Sequential(*modules)