def train()

in source/sagemaker/sagemaker_graph_fraud_detection/dgl_fraud_detection/train_dgl_mxnet_entry_point.py [0:0]


# Imports used by this function (module-level in the source file)
import time
import logging

import numpy as np
import mxnet as mx
from mxnet import nd, autograd

# evaluate, get_model_class_predictions and get_metrics are helper functions
# defined/imported elsewhere in the source file.


def train(model, trainer, loss, features, labels, train_loader, test_loader, train_g, test_g, train_mask, test_mask,
          ctx, n_epochs, batch_size, output_dir, thresh, scale_pos_weight, compute_metrics=True, mini_batch=True):
    duration = []  # wall-clock training time per epoch
    for epoch in range(n_epochs):
        tic = time.time()
        loss_val = 0.

        for n, batch in enumerate(train_loader):
            # logging.info("Iteration: {:05d}".format(n))
            # Sample the computation block (node flow) for this mini-batch of seed nodes
            node_flow, batch_nids = train_g.sample_block(nd.array(batch).astype('int64'))
            batch_indices = nd.array(batch, ctx=ctx)
            with autograd.record():
                # Forward pass on the sampled block, gathering the input features of the sampled nodes
                pred = model(node_flow, features[batch_nids.as_in_context(ctx)])
                # Per-node weights (scale_pos_weight * train_mask)[batch_indices]: nodes outside the
                # training set get weight 0, training nodes are scaled by scale_pos_weight
                l = loss(pred, labels[batch_indices], mx.nd.expand_dims(scale_pos_weight*train_mask, 1)[batch_indices])
                # Average the loss over the nodes in the mini-batch
                l = l.sum()/len(batch)

            l.backward()
            # The loss is already averaged over the batch, so step with batch_size=1
            trainer.step(batch_size=1, ignore_stale_grad=True)

            loss_val += l.asscalar()
            # logging.info("Current loss {:04f}".format(loss_val/(n+1)))

        duration.append(time.time() - tic)
        # Per-epoch evaluation on the training graph (returns the F1 score logged below)
        metric = evaluate(model, train_g, features, labels, train_mask, ctx, batch_size, mini_batch)
        logging.info("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | F1 {:.4f}".format(
                epoch, np.mean(duration), loss_val/(n+1), metric))

    # Final class predictions and probabilities on the test graph, thresholded at `thresh`
    class_preds, pred_proba = get_model_class_predictions(model, test_g, test_loader, features, ctx, threshold=thresh)

    if compute_metrics:
        # Unpack evaluation metrics (acc, f1, precision, recall, roc, pr, ap) and the confusion matrix
        acc, f1, p, r, roc, pr, ap, cm = get_metrics(class_preds, pred_proba, labels, test_mask, output_dir)
        logging.info("Metrics")
        logging.info("""Confusion Matrix:
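A minimal sketch of how this function might be wired up, assuming the model, graphs, masks, and data loaders have already been built earlier in the entry-point script. The optimizer choice, learning rate, hyperparameter values, the sigmoid BCE loss, and the /opt/ml/model output path are illustrative assumptions, not taken from the source; the loss is chosen to match the call pattern above, where a per-sample weight is passed as the third argument.

from mxnet import gluon

ctx = mx.cpu()  # or mx.gpu(0) if a GPU is available

# Sigmoid binary cross-entropy on raw logits; the third positional argument
# passed inside train() is treated as the per-sample weight.
loss = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': 1e-2})

# Hypothetical hyperparameter values, for illustration only
train(model, trainer, loss, features, labels, train_loader, test_loader,
      train_g, test_g, train_mask, test_mask, ctx,
      n_epochs=100, batch_size=1024, output_dir='/opt/ml/model',
      thresh=0.5, scale_pos_weight=10.0)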