# utils.py
import torch
from time import time
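
# NOTE: `loss_func` is used by `run_gnn_training` below but defined elsewhere
# in this file. For reference, a minimal sketch of a QUBO-style relaxation
# loss is given here. This is an assumption about its shape, not a verbatim
# copy: it evaluates the quadratic form x^T Q x on the relaxed per-node
# probabilities, which is differentiable and drives the GNN training.
def loss_func(probs, q_mat):
    """Differentiable QUBO cost probs^T Q probs (illustrative sketch)."""
    probs_ = torch.unsqueeze(probs, 1)            # shape (n,) -> (n, 1)
    return (probs_.T @ q_mat @ probs_).squeeze()  # 0-dim tensor (scalar cost)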

def run_gnn_training(q_torch, dgl_graph, net, embed, optimizer,
                     number_epochs, tol, patience, prob_threshold):
    """
    Wrapper function to run and monitor GNN training. Includes early stopping.
    """
    # The trainable node embeddings double as the GNN input features
    inputs = embed.weight

    prev_loss = 1.  # initial loss value (arbitrary)
    count = 0       # consecutive epochs the early-stopping criterion has been met

    # initialize the incumbent solution with the all-zeros bitstring
    best_bitstring = torch.zeros((dgl_graph.number_of_nodes(),)).type(q_torch.dtype).to(q_torch.device)
    best_loss = loss_func(best_bitstring.float(), q_torch)

    t_gnn_start = time()
    # Training loop
    for epoch in range(number_epochs):

        # forward pass: per-node assignment probabilities
        probs = net(dgl_graph, inputs)[:, 0]  # collapse extra output dimension from the model

        # evaluate the relaxed QUBO cost function
        loss = loss_func(probs, q_torch)
        loss_ = loss.detach().item()

        # project relaxed probabilities onto a binary bitstring
        bitstring = (probs.detach() >= prob_threshold) * 1

        if loss < best_loss:
            best_loss = loss.detach()  # detach so the incumbent does not retain the autograd graph
            best_bitstring = bitstring

        if epoch % 1000 == 0:
            print(f'Epoch: {epoch}, Loss: {loss_}')
        # early stopping check: trigger if the loss increases
        # or if the change in loss is within tolerance
        if abs(loss_ - prev_loss) <= tol or (loss_ - prev_loss) > 0:
            count += 1
        else:
            count = 0

        if count >= patience:
            print(f'Stopping early on epoch {epoch} (patience: {patience})')
            break

        # update loss tracking
        prev_loss = loss_

        # run optimization with backpropagation
        optimizer.zero_grad()  # clear gradients from the previous step
        loss.backward()        # backpropagate through the compute graph
        optimizer.step()       # update weights
    t_gnn = time() - t_gnn_start
    print(f'GNN training (n={dgl_graph.number_of_nodes()}) took {round(t_gnn, 3)} s')
    print(f'GNN final continuous loss: {loss_}')
    print(f'GNN best continuous loss: {best_loss.item()}')

    final_bitstring = (probs.detach() >= prob_threshold) * 1

    return net, epoch, final_bitstring, best_bitstring
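

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only). Everything below, including the
# toy cycle graph, the ToyGCN model, and the MIS-style QUBO matrix, is
# hypothetical scaffolding added for demonstration; it is not part of this
# repo's actual pipeline.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import itertools

    import dgl
    import torch.nn as nn
    from dgl.nn import GraphConv

    class ToyGCN(nn.Module):
        """Two-layer GCN emitting one probability per node (sketch)."""
        def __init__(self, in_dim, hidden_dim):
            super().__init__()
            self.conv1 = GraphConv(in_dim, hidden_dim)
            self.conv2 = GraphConv(hidden_dim, 1)

        def forward(self, g, x):
            h = torch.relu(self.conv1(g, x))
            return torch.sigmoid(self.conv2(g, h))  # shape (n, 1)

    # toy 5-node cycle graph; self-loops so GraphConv sees each node's own feature
    g = dgl.add_self_loop(dgl.graph(([0, 1, 2, 3, 4], [1, 2, 3, 4, 0])))
    n = g.number_of_nodes()

    # MIS-style QUBO: reward selecting a node, penalize selecting adjacent pairs
    penalty = 2.0
    q = -torch.eye(n)
    src, dst = g.edges()
    for u, v in zip(src.tolist(), dst.tolist()):
        if u != v:  # skip the self-loops added above
            q[u, v] += penalty
    q_torch = q.float()

    in_dim, hidden_dim = 8, 8
    embed = nn.Embedding(n, in_dim)
    net = ToyGCN(in_dim, hidden_dim)
    optimizer = torch.optim.Adam(
        itertools.chain(net.parameters(), embed.parameters()), lr=1e-3)

    net, last_epoch, final_bs, best_bs = run_gnn_training(
        q_torch, g, net, embed, optimizer,
        number_epochs=5000, tol=1e-4, patience=100, prob_threshold=0.5)
    print(f'best bitstring: {best_bs.tolist()}')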