in source/sagemaker/sagemaker_graph_fraud_detection/dgl_fraud_detection/train_dgl_mxnet_entry_point.py [0:0]
def train(model, trainer, loss, features, labels, train_loader, test_loader, train_g, test_g, train_mask, test_mask,
ctx, n_epochs, batch_size, output_dir, thresh, scale_pos_weight, compute_metrics=True, mini_batch=True):
duration = []
for epoch in range(n_epochs):
tic = time.time()
loss_val = 0.
for n, batch in enumerate(train_loader):
# logging.info("Iteration: {:05d}".format(n))
node_flow, batch_nids = train_g.sample_block(nd.array(batch).astype('int64'))
batch_indices = nd.array(batch, ctx=ctx)
with autograd.record():
pred = model(node_flow, features[batch_nids.as_in_context(ctx)])
l = loss(pred, labels[batch_indices], mx.nd.expand_dims(scale_pos_weight*train_mask, 1)[batch_indices])
l = l.sum()/len(batch)
l.backward()
trainer.step(batch_size=1, ignore_stale_grad=True)
loss_val += l.asscalar()
# logging.info("Current loss {:04f}".format(loss_val/(n+1)))
duration.append(time.time() - tic)
metric = evaluate(model, train_g, features, labels, train_mask, ctx, batch_size, mini_batch)
logging.info("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | F1 {:.4f}".format(
epoch, np.mean(duration), loss_val/(n+1), metric))
class_preds, pred_proba = get_model_class_predictions(model, test_g, test_loader, features, ctx, threshold=thresh)
if compute_metrics:
acc, f1, p, r, roc, pr, ap, cm = get_metrics(class_preds, pred_proba, labels, test_mask, output_dir)
logging.info("Metrics")
logging.info("""Confusion Matrix: