in src/sagemaker/FD_SL_DGL/gnn_fraud_detection_dgl/fd_sl_train_entry_point.py [0:0]
def train_fg(model, optim, loss, features, labels, train_g, test_g, test_mask,
device, n_epochs, thresh, compute_metrics=True):
"""
A full graph version of RGCN training
"""
duration = []
for epoch in range(n_epochs):
tic = time.time()
loss_val = 0.
pred = model(train_g, features.to(device))
l = loss(pred, labels)
optim.zero_grad()
l.backward()
optim.step()
loss_val += l
duration.append(time.time() - tic)
metric = evaluate(model, train_g, features, labels, device)
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | f1 {:.4f} ".format(
epoch, np.mean(duration), loss_val, metric))
class_preds, pred_proba = get_model_class_predictions(model,
test_g,
features,
labels,
device,
threshold=thresh)
if compute_metrics:
acc, f1, p, r, roc, pr, ap, cm = get_metrics(class_preds, pred_proba, labels.numpy(), test_mask.numpy(), './')
print("Metrics")
print("""Confusion Matrix: