in source/sagemaker/sagemaker_graph_fraud_detection/dgl_fraud_detection/train_dgl_mxnet_entry_point.py [0:0]
def get_model(g, hyperparams, in_feats, n_classes, ctx, model_dir=None):
    """Build the GNN described by ``hyperparams``, either fresh or restored from disk.

    Args:
        g: Graph handed to the model constructor. When ``model_dir`` is given,
            this argument is ignored and the graph pickled in
            ``model_dir/graph.pkl`` is used instead.
        hyperparams: Dict of hyperparameters. When ``model_dir`` is given, this
            argument is ignored and ``model_dir/model_hyperparams.pkl`` is used.
            Keys read:
              - always: 'heterogeneous', 'n_hidden', 'n_layers'
              - heterogeneous: 'embedding_size'
              - homogeneous: 'model', 'dropout', 'no_features'
                  ('embedding_size' too when 'no_features' is truthy)
              - model == 'graphsage': 'aggregator_type'
              - any other homogeneous model (GAT): 'num_heads',
                  'num_out_heads', 'attn_drop', 'alpha', 'residual'
        in_feats: Input feature size passed to the model constructor.
        n_classes: Number of output classes.
        ctx: MXNet context the parameters are initialized on or loaded onto.
        model_dir: Optional directory holding 'model_hyperparams.pkl',
            'graph.pkl' and 'model.params' from a previous training run.

    Returns:
        The Gluon model, with parameters loaded from ``model_dir/model.params``
        when ``model_dir`` is given, otherwise freshly initialized on ``ctx``.
    """
    if model_dir:  # load using saved model state
        # NOTE(security): pickle.load can execute arbitrary code; model_dir must
        # only ever point at trusted artifacts (e.g. this job's own saved output).
        with open(os.path.join(model_dir, 'model_hyperparams.pkl'), 'rb') as f:
            hyperparams = pickle.load(f)
        with open(os.path.join(model_dir, 'graph.pkl'), 'rb') as f:
            g = pickle.load(f)

    if hyperparams['heterogeneous']:
        model = HeteroRGCN(g,
                           in_feats,
                           hyperparams['n_hidden'],
                           n_classes,
                           hyperparams['n_layers'],
                           hyperparams['embedding_size'],
                           ctx)
    else:
        if hyperparams['model'] == 'gcn':
            model = GCN(g,
                        in_feats,
                        hyperparams['n_hidden'],
                        n_classes,
                        hyperparams['n_layers'],
                        nd.relu,
                        hyperparams['dropout'])
        elif hyperparams['model'] == 'graphsage':
            model = GraphSAGE(g,
                              in_feats,
                              hyperparams['n_hidden'],
                              n_classes,
                              hyperparams['n_layers'],
                              nd.relu,
                              hyperparams['dropout'],
                              hyperparams['aggregator_type'])
        else:
            # Any other model name falls through to GAT.
            # One head count per hidden layer, plus the output layer's head count.
            heads = ([hyperparams['num_heads']] * hyperparams['n_layers']) + [hyperparams['num_out_heads']]
            model = GAT(g,
                        in_feats,
                        hyperparams['n_hidden'],
                        n_classes,
                        hyperparams['n_layers'],
                        heads,
                        # ELU activation, expressed through nd.LeakyReLU's act_type.
                        gluon.nn.Lambda(lambda data: nd.LeakyReLU(data, act_type='elu')),
                        hyperparams['dropout'],
                        hyperparams['attn_drop'],
                        hyperparams['alpha'],
                        hyperparams['residual'])

        if hyperparams['no_features']:
            # Wrap the homogeneous model so it is fed learned node embeddings
            # of size 'embedding_size'.
            model = NodeEmbeddingGNN(model, in_feats, hyperparams['embedding_size'])

    if model_dir:
        # Load onto the requested context. Without ctx, Gluon uses its default
        # context, which can differ from where the caller wants the model.
        model.load_parameters(os.path.join(model_dir, 'model.params'), ctx=ctx)
    else:
        model.initialize(ctx=ctx)

    return model