def train()

in source/sagemaker/sagemaker_graph_entity_resolution/dgl_entity_resolution/train_dgl_pytorch_entity_resolution.py [0:0]
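
The function body uses time, torch, numpy, and logging, plus an evaluate() helper defined elsewhere in the same file. A minimal sketch of the module-level imports it assumes:

import logging
import time

import numpy as np
import torch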


def train(g, model, train_dataloader, train_triplets, test_triplets, user_features, web_features, optimizer, batch_size,
          n_neighbors, n_epochs, negative_rate, grad_norm, cuda, device=None, run_eval=True):
    for epoch in range(n_epochs):
        tic = time.time()
        duration = []
        loss_val = 0.
        mrr = -1.  # placeholder reported in the epoch log when evaluation is skipped

        model.train()
        for n, (input_nodes, pos_pair_graph, neg_pair_graph, blocks) in enumerate(train_dataloader):
            # look up the sampled node IDs and their input features for this minibatch
            user_nodes, website_nodes = input_nodes['user'], input_nodes['website']
            u, w = user_features[user_nodes], web_features[website_nodes]

            # collect positive and negative 'same_entity' edges and build binary labels:
            # the first len(true_srcs) pairs are true matches, the rest are sampled negatives
            true_srcs, true_dsts = pos_pair_graph.all_edges(etype='same_entity')
            false_srcs, false_dsts = neg_pair_graph.all_edges(etype='same_entity')
            sources, sinks = torch.cat((true_srcs, false_srcs)), torch.cat((true_dsts, false_dsts))
            labels = torch.zeros((negative_rate + 1) * len(true_srcs))
            labels[:len(true_srcs)] = 1

            if cuda:
                # move features, blocks, edge endpoints, and labels onto the GPU
                user_nodes, website_nodes, u, w = user_nodes.cuda(), website_nodes.cuda(), u.cuda(), w.cuda()
                blocks = [blk.to(device) for blk in blocks]
                sources, sinks, labels = sources.cuda(), sinks.cuda(), labels.cuda()

            # compute node embeddings with the GNN over the sampled blocks
            embeddings = model(blocks, user_nodes, website_nodes, u, w)

            # score each candidate (source, sink) pair against its label
            loss = model.get_loss(embeddings, sources, sinks, labels)

            optimizer.zero_grad()
            loss.backward()
            # clip gradients to stabilize training before the optimizer step
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)
            optimizer.step()

            loss_val += loss.item()

        duration.append(time.time() - tic)
        # run evaluation every 5 epochs and on the final epoch
        do_eval = run_eval and ((epoch % 5 == 0) or (epoch == n_epochs - 1))
        if do_eval:
            mrr = evaluate(model, g, train_triplets, test_triplets, user_features, web_features,
                           batch_size, n_neighbors, device=device)

        logging.info("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | MRR {:.4f}".format(
            epoch, np.mean(duration), loss_val / (n + 1), mrr))
    return model
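
The training loop unpacks each minibatch as (input_nodes, pos_pair_graph, neg_pair_graph, blocks), which matches what DGL's EdgeDataLoader yields when a negative sampler is attached. Below is a minimal, hypothetical sketch of how such a dataloader and a call to train() might be wired up; the DGL 0.6-style API usage, the two-layer sampling depth, and every variable other than the function's own parameters are assumptions, not code from this repository.

import dgl
import torch

# sample a fixed number of neighbors per layer (two-layer GNN assumed)
sampler = dgl.dataloading.MultiLayerNeighborSampler([n_neighbors, n_neighbors])

# iterate over 'same_entity' edges, pairing each positive edge with
# uniformly sampled negatives at the requested negative_rate
train_eids = {'same_entity': torch.arange(g.num_edges('same_entity'))}
train_dataloader = dgl.dataloading.EdgeDataLoader(
    g, train_eids, sampler,
    negative_sampler=dgl.dataloading.negative_sampler.Uniform(negative_rate),
    batch_size=batch_size, shuffle=True, drop_last=False)

model = train(g, model, train_dataloader, train_triplets, test_triplets,
              user_features, web_features, optimizer, batch_size,
              n_neighbors, n_epochs, negative_rate, grad_norm,
              cuda=torch.cuda.is_available(), device=device)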