in utils.py [0:0]
def get_gnn(n_nodes, gnn_hypers, opt_params, torch_device, torch_dtype):
    """
    Build a GNN, a trainable node-embedding table, and one Adam optimizer over both.

    Input:
        n_nodes: Number of nodes in the problem graph; the embedding table is
            created with one row per node.
        gnn_hypers: Mapping of GNN structure hyperparameters. It must provide the
            keys 'dim_embedding', 'hidden_dim', 'dropout' and 'number_classes'
            (a missing key raises KeyError).
        opt_params: Keyword arguments forwarded unchanged to torch.optim.Adam
            (for example a learning rate under 'lr').
        torch_device: Device that both modules are moved to with ``.to()``; it is
            also passed to the GCN_dev constructor.
        torch_dtype: Dtype that both modules are cast to with ``.type()``.
    Output:
        net: GCN_dev instance, cast and moved as described above
        embed: nn.Embedding(n_nodes, dim_embedding), used as the input to the GNN
        optimizer: Adam optimizer that updates the parameters of both net and embed
    """
    # Pull the structural hyperparameters out of the config (same key order as before).
    emb_dim, hid_dim, p_drop, n_classes = (
        gnn_hypers[key]
        for key in ('dim_embedding', 'hidden_dim', 'dropout', 'number_classes')
    )

    # Keep this construction order: the network is built before the embedding.
    # nn.Embedding draws random initial weights (and GCN_dev presumably does too),
    # so swapping the order would change results under a fixed seed.
    # NOTE: GCN_dev takes the class count positionally *before* the dropout rate.
    model = (
        GCN_dev(emb_dim, hid_dim, n_classes, p_drop, torch_device)
        .type(torch_dtype)
        .to(torch_device)
    )
    node_embedding = nn.Embedding(n_nodes, emb_dim).type(torch_dtype).to(torch_device)

    # A single optimizer covers both modules, so the input embedding is trained
    # jointly with the network weights.
    trainable = chain(model.parameters(), node_embedding.parameters())
    adam = torch.optim.Adam(trainable, **opt_params)

    return model, node_embedding, adam