# Excerpt: train() from source/sagemaker/baseline/train_pytorch_mlp_entity_resolution.py

def train(model, dataloader, features, n_epochs, optimizer, neg_rate, cuda):
    """Train `model` with uniform negative sampling for entity resolution.

    For each batch of positive index pairs (i, j) from `dataloader`,
    `neg_rate` negative endpoints per positive are sampled uniformly from
    the node set (rows of `features`), and `model.get_loss` scores all
    pairs against 0/1 labels (1 = positive pair, 0 = sampled negative).

    Args:
        model: module with `forward(features) -> embeddings` and
            `get_loss(embed, i, j, labels) -> scalar loss tensor`.
        dataloader: iterable yielding (i, j) index-tensor batches of
            positive pairs.
        features: (num_nodes, feat_dim) tensor; `features.shape[0]` is the
            sampling range for negatives.
        n_epochs: number of passes over `dataloader`.
        optimizer: torch optimizer over `model`'s parameters.
        neg_rate: number of negatives drawn per positive pair.
        cuda: when True, move index/label tensors to GPU (model and
            features are assumed to already be on the right device).
    """
    for epoch in range(n_epochs):
        loss_val = 0.
        durations = []
        metric = -1  # MRR placeholder: not computed inside this loop
        n_batches = 0
        for i, j in dataloader:
            # Reset the clock per batch so `durations` holds per-batch
            # times (the old single per-epoch tic made them cumulative,
            # which rendered the logged mean meaningless).
            tic = time.time()

            labels = torch.zeros((neg_rate + 1) * len(i))
            labels[:len(i)] = 1  # first len(i) entries are the true positives
            # Append uniformly sampled negative endpoints on both sides.
            i = torch.cat((i, torch.tensor(np.random.choice(features.shape[0], neg_rate * len(i)))))
            j = torch.cat((j, torch.tensor(np.random.choice(features.shape[0], neg_rate * len(j)))))

            if cuda:
                i, j, labels = i.cuda(), j.cuda(), labels.cuda()

            embed = model(features)
            loss = model.get_loss(embed, i, j, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_val += loss.item()
            durations.append(time.time() - tic)
            n_batches += 1

        if n_batches == 0:
            # Previously an empty dataloader raised NameError on `n` here.
            logging.warning("Epoch {:05d} | empty dataloader, nothing trained".format(epoch))
            continue
        logging.info("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | MRR {:.4f}".format(
            epoch, np.mean(durations), loss_val / n_batches, metric))