in source/sagemaker/sagemaker_graph_entity_resolution/dgl_entity_resolution/train_dgl_pytorch_entity_resolution.py [0:0]
def train(g, model, train_dataloader, train_triplets, test_triplets, user_features, web_features, optimizer, batch_size,
          n_neighbors, n_epochs, negative_rate, grad_norm, cuda, device=None, run_eval=True):
    """Train ``model`` on 'same_entity' edge batches and return it.

    For each epoch the function iterates ``train_dataloader``, builds a binary
    label vector (1 for edges of the positive pair graph, 0 for edges of the
    negative pair graph), computes ``model.get_loss`` on the embeddings the
    model produces for the batch, and takes one clipped optimizer step per
    batch. One summary line is logged per epoch.

    Args:
        g: Graph object; only forwarded to ``evaluate``.
        model: Called as ``model(blocks, user_nodes, website_nodes, u, w)``;
            must also provide ``train()``, ``parameters()`` and
            ``get_loss(embeddings, sources, sinks, labels)``.
        train_dataloader: Iterable yielding
            ``(input_nodes, pos_pair_graph, neg_pair_graph, blocks)`` where
            ``input_nodes`` has ``'user'`` and ``'website'`` entries and both
            pair graphs support ``all_edges(etype='same_entity')``.
        train_triplets: Only forwarded to ``evaluate``.
        test_triplets: Only forwarded to ``evaluate``.
        user_features: Indexed with the batch's ``'user'`` input nodes.
        web_features: Indexed with the batch's ``'website'`` input nodes.
        optimizer: Provides ``zero_grad()`` and ``step()``.
        batch_size: Only forwarded to ``evaluate``.
        n_neighbors: Only forwarded to ``evaluate``.
        n_epochs: Number of passes over ``train_dataloader``.
        negative_rate: Used to size the label vector as
            ``(negative_rate + 1) * number_of_positive_edges``.
        grad_norm: Maximum gradient norm passed to ``clip_grad_norm_``.
        cuda: When truthy, the batch tensors and blocks are moved to the GPU.
        device: Device the batch is moved to when ``cuda`` is set (falls back
            to the current CUDA device when ``None``); also forwarded to
            ``evaluate``.
        run_eval: When truthy, ``evaluate`` runs on every fifth epoch and on
            the final epoch.

    Returns:
        The same ``model`` object that was passed in.
    """
    # Resolve a single target so tensors and blocks always land on the same
    # device. Previously tensors used ``.cuda()`` (current CUDA device) while
    # blocks used ``.to(device)``, which diverge when ``device`` is not the
    # current CUDA device or is None.
    target = None
    if cuda:
        target = device if device is not None else torch.device('cuda')

    for epoch in range(n_epochs):
        tic = time.time()
        duration = []
        loss_val = 0.
        mrr = -1.  # logged as-is on epochs where evaluation is skipped
        n_batches = 0

        model.train()
        for input_nodes, pos_pair_graph, neg_pair_graph, blocks in train_dataloader:
            user_nodes, website_nodes = input_nodes['user'], input_nodes['website']
            u, w = user_features[user_nodes], web_features[website_nodes]

            # Positive edges come first, negatives after; labels follow that order.
            true_srcs, true_dsts = pos_pair_graph.all_edges(etype='same_entity')
            false_srcs, false_dsts = neg_pair_graph.all_edges(etype='same_entity')
            sources = torch.cat((true_srcs, false_srcs))
            sinks = torch.cat((true_dsts, false_dsts))
            # Assumes the dataloader draws `negative_rate` negative edges per
            # positive edge — TODO confirm against the sampler configuration.
            labels = torch.zeros((negative_rate + 1) * len(true_srcs))
            labels[:len(true_srcs)] = 1

            if cuda:
                user_nodes, website_nodes = user_nodes.to(target), website_nodes.to(target)
                u, w = u.to(target), w.to(target)
                blocks = [blk.to(target) for blk in blocks]
                sources, sinks, labels = sources.to(target), sinks.to(target), labels.to(target)

            embeddings = model(blocks, user_nodes, website_nodes, u, w)
            loss = model.get_loss(embeddings, sources, sinks, labels)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)
            optimizer.step()

            loss_val += loss.item()
            # NOTE(review): `tic` is taken once per epoch, so each entry is the
            # time elapsed since the epoch started, not a per-batch time.
            # Kept as-is so logged values stay comparable — confirm intent.
            duration.append(time.time() - tic)
            n_batches += 1

        do_eval = run_eval and ((epoch % 5 == 0) or (epoch == n_epochs - 1))
        if do_eval:
            mrr = evaluate(model, g, train_triplets, test_triplets, user_features, web_features, batch_size, n_neighbors,
                           device=device)

        # An empty dataloader used to raise UnboundLocalError here because the
        # batch index was never bound; report NaN averages instead.
        avg_time = np.mean(duration) if duration else float('nan')
        avg_loss = loss_val / n_batches if n_batches else float('nan')
        logging.info("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | MRR {:.4f}".format(
            epoch, avg_time, avg_loss, mrr))
    return model