def get_model()

in source/sagemaker/sagemaker_graph_fraud_detection/dgl_fraud_detection/train_dgl_mxnet_entry_point.py [0:0]

Builds the graph neural network selected by the training hyperparameters and either initializes its parameters on the given context or restores a previously trained model from model_dir.

Imports required by the function body (they live at the top of the entry-point script, outside this excerpt):

import os
import pickle

from mxnet import gluon, nd

# HeteroRGCN, GCN, GraphSAGE, GAT and NodeEmbeddingGNN are the model classes
# defined in this package and imported alongside the standard imports above.


def get_model(g, hyperparams, in_feats, n_classes, ctx, model_dir=None):

    if model_dir:  # load using saved model state
        # restore the hyperparameters and graph that were pickled at training time,
        # overriding the ones passed in as arguments
        with open(os.path.join(model_dir, 'model_hyperparams.pkl'), 'rb') as f:
            hyperparams = pickle.load(f)
        with open(os.path.join(model_dir, 'graph.pkl'), 'rb') as f:
            g = pickle.load(f)

    if hyperparams['heterogeneous']:
        # heterogeneous graph: relational GCN over the node/edge type schema
        model = HeteroRGCN(g,
                           in_feats,
                           hyperparams['n_hidden'],
                           n_classes,
                           hyperparams['n_layers'],
                           hyperparams['embedding_size'],
                           ctx)
    else:
        # homogeneous graph: pick GCN, GraphSAGE or GAT based on hyperparams['model']
        if hyperparams['model'] == 'gcn':
            model = GCN(g,
                        in_feats,
                        hyperparams['n_hidden'],
                        n_classes,
                        hyperparams['n_layers'],
                        nd.relu,
                        hyperparams['dropout'])
        elif hyperparams['model'] == 'graphsage':
            model = GraphSAGE(g,
                              in_feats,
                              hyperparams['n_hidden'],
                              n_classes,
                              hyperparams['n_layers'],
                              nd.relu,
                              hyperparams['dropout'],
                              hyperparams['aggregator_type'])
        else:
            # fall back to GAT: one attention-head count per hidden layer plus the output
            # layer, e.g. n_layers=2, num_heads=8, num_out_heads=1 gives heads = [8, 8, 1]
            heads = ([hyperparams['num_heads']] * hyperparams['n_layers']) + [hyperparams['num_out_heads']]
            model = GAT(g,
                        in_feats,
                        hyperparams['n_hidden'],
                        n_classes,
                        hyperparams['n_layers'],
                        heads,
                        # ELU activation wrapped as a Gluon block (nd.LeakyReLU with act_type='elu')
                        gluon.nn.Lambda(lambda data: nd.LeakyReLU(data, act_type='elu')),
                        hyperparams['dropout'],
                        hyperparams['attn_drop'],
                        hyperparams['alpha'],
                        hyperparams['residual'])

    if hyperparams['no_features']:
        # nodes carry no input features: wrap the GNN so it learns node embeddings as its inputs
        model = NodeEmbeddingGNN(model, in_feats, hyperparams['embedding_size'])

    if model_dir:
        # restore the trained weights saved next to the pickled hyperparameters and graph
        model.load_parameters(os.path.join(model_dir, 'model.params'))
    else:
        # fresh model: initialize parameters on the target context (CPU or GPU)
        model.initialize(ctx=ctx)

    return model
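
Below is a minimal usage sketch, assuming DGL built with the MXNet backend. The toy graph, the hyperparameter values and the in_feats/n_classes arguments are illustrative placeholders; only the hyperparameter keys are taken from the function above.

import random

import dgl
import mxnet as mx

# toy homogeneous graph with 100 nodes and 500 random edges (illustration only)
src = [random.randrange(100) for _ in range(500)]
dst = [random.randrange(100) for _ in range(500)]
g = dgl.graph((src, dst), num_nodes=100)

# only the keys below are read on the homogeneous GCN path of get_model
hyperparams = {
    'heterogeneous': False,
    'model': 'gcn',
    'n_hidden': 16,
    'n_layers': 2,
    'dropout': 0.2,
    'no_features': False,
}

model = get_model(g, hyperparams, in_feats=32, n_classes=2, ctx=mx.cpu())

Passing a model_dir instead would make get_model ignore the g and hyperparams arguments and reload the pickled model_hyperparams.pkl, graph.pkl and the saved model.params from that directory.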