in source/sagemaker/baseline/train_pytorch_mlp_entity_resolution.py [0:0]
def train(model, dataloader, features, n_epochs, optimizer, neg_rate, cuda):
    """Train ``model`` on observed pairs plus uniformly sampled negative pairs.

    For every batch ``(i, j)`` yielded by ``dataloader``, ``neg_rate`` random
    pairs per observed pair are appended (indices drawn uniformly with
    replacement from ``range(features.shape[0])``). Observed pairs are labelled
    1 and sampled pairs 0. The model is run on the full ``features`` input every
    batch and ``model.get_loss(embed, i, j, labels)`` is minimised with
    ``optimizer``.

    Args:
        model: callable on ``features`` that also exposes
            ``get_loss(embed, i, j, labels)`` returning a scalar loss tensor.
        dataloader: iterable of ``(i, j)`` index tensors. Presumably these are
            row indices into ``features`` for matching entity pairs; verify
            against the caller.
        features: input passed to ``model``; its first dimension is the pool
            that negative indices are sampled from.
        n_epochs: number of passes over ``dataloader``.
        optimizer: torch optimizer over ``model``'s parameters.
        neg_rate: number of negative pairs sampled per observed pair.
        cuda: if true, move the index and label tensors to the GPU. ``features``
            and ``model`` are NOT moved here; the caller is assumed to have
            done so.

    Per epoch, prints the summed loss and logs the epoch time and mean batch
    loss. MRR is logged as the placeholder -1 because no evaluation is done here.
    """
    num_nodes = features.shape[0]
    for epoch in range(n_epochs):
        tic = time.time()
        loss_val = 0.
        n_batches = 0
        metric = -1  # placeholder: no ranking metric is computed in this loop

        for i, j in dataloader:
            n_pos = len(i)
            # Layout: first n_pos entries are observed pairs (label 1), the
            # following neg_rate * n_pos entries are random pairs (label 0).
            labels = torch.zeros((neg_rate + 1) * n_pos)
            labels[:n_pos] = 1
            i = torch.cat((i, torch.tensor(np.random.choice(num_nodes, neg_rate * n_pos))))
            j = torch.cat((j, torch.tensor(np.random.choice(num_nodes, neg_rate * len(j)))))
            if cuda:
                i, j, labels = i.cuda(), j.cuda(), labels.cuda()

            embed = model(features)
            loss = model.get_loss(embed, i, j, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_val += loss.item()
            n_batches += 1

        epoch_time = time.time() - tic
        if n_batches == 0:
            # Previously this crashed on the unbound loop variable `n`.
            logging.warning("Epoch {:05d} | dataloader yielded no batches".format(epoch))
        print(loss_val)
        logging.info("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | MRR {:.4f}".format(
            epoch, epoch_time, loss_val / max(n_batches, 1), metric))