def batchify()

in src/datasets.py [0:0]


    def batchify(self, batch_size, device, num_neg=None):
        """Return the next mini-batch from ``self.data``, reshuffling when exhausted.

        When ``self.is_empty()`` reports the cursor has run off the data, the
        rows of ``self.data`` are randomly permuted, the cursor ``self._idx``
        is reset to 0 and ``self._epoch_idx`` is incremented before slicing.

        Args:
            batch_size: Number of rows to return when ``num_neg`` is None.
                Otherwise ``batch_size // (2 * num_neg)`` positive rows are
                drawn from ``self.data``.
            device: Device the returned tensors are moved to / created on.
            num_neg: If given, number of corrupted copies generated for each
                positive row.

        Returns:
            If ``num_neg`` is None: a slice of up to ``batch_size`` rows of
            ``self.data`` on ``device``.
            Otherwise a tuple ``(pos_batch, neg_batch, label)``:
            ``pos_batch`` is the positive slice tiled ``num_neg`` times,
            ``neg_batch`` is a copy of it whose column 2 is replaced by ids
            sampled uniformly from ``[0, self.n_ent)``, and ``label`` is an
            ``(n, 1)`` tensor of ones where ``n`` is the row count of
            ``pos_batch``.

        Raises:
            ValueError: if ``batch_size // (2 * num_neg)`` is less than 1, in
                which case the cursor could never advance.
        """
        if self.is_empty():
            # Start a new epoch: reshuffle rows and rewind the cursor.
            self.data = self.data[torch.randperm(self.data.shape[0]), :]
            self._idx = 0
            self._epoch_idx += 1
        if num_neg is None:
            batch = self.data[self._idx: self._idx + batch_size].to(device)
            self._idx = self._idx + batch_size
            return batch

        # Each positive row ends up as num_neg rows in pos_batch plus num_neg
        # rows in neg_batch, i.e. 2 * num_neg rows in total.
        pos_size = batch_size // (2 * num_neg)
        if pos_size < 1:
            raise ValueError(
                f"batch_size={batch_size} is too small for num_neg={num_neg}; "
                f"need batch_size >= {2 * num_neg}"
            )
        # Move the un-tiled slice first so only pos_size rows are transferred.
        pos_batch = self.data[self._idx: self._idx + pos_size].to(device)
        pos_batch = pos_batch.repeat(num_neg, 1)
        neg_batch = pos_batch.clone()
        n = pos_batch.shape[0]  # (rows in the slice) * num_neg
        # torch.randint's `high` is exclusive, so this draws ids in
        # [0, n_ent). Assumes entity ids are 0-based and below n_ent (confirm).
        # Column 2 is presumably the tail entity of a triple -- TODO confirm.
        neg_batch[:, 2] = torch.randint(
            low=0, high=self.n_ent, size=(n,), device=device
        )
        label = torch.ones(n, 1, device=device)
        self._idx = self._idx + pos_size
        return pos_batch, neg_batch, label