in src/datasets.py [0:0]
def batchify(self, batch_size, device, num_neg=None):
    """Return the next mini-batch from ``self.data`` and advance the cursor.

    If ``self.is_empty()`` is true (defined outside this block; presumably
    it means the cursor has run past the data -- confirm), the rows of
    ``self.data`` are randomly permuted, ``self._idx`` is reset to 0 and
    ``self._epoch_idx`` is incremented before the batch is sliced.

    Args:
        batch_size: Number of rows to return when ``num_neg`` is None.
            With negative sampling it acts as a budget: the number of
            positive rows taken from the data is
            ``int(batch_size / (2 * num_neg))``, but never fewer than one.
        device: Device the returned tensors are placed on.
        num_neg: Number of corrupted copies generated per positive row.
            ``None`` disables negative sampling.

    Returns:
        Without ``num_neg``: a slice of at most ``batch_size`` rows of
        ``self.data`` moved to ``device``.

        With ``num_neg``: a tuple ``(pos_batch, neg_batch, label)`` where
        ``pos_batch`` is the positive slice tiled ``num_neg`` times,
        ``neg_batch`` is a copy of it whose column 2 is replaced by entity
        ids drawn uniformly from ``[0, self.n_ent)``, and ``label`` is a
        tensor of ones with shape ``(n, 1)``, ``n`` being the number of
        rows in ``pos_batch``.
    """
    if self.is_empty():
        # Start a new pass over the data in a fresh random order.
        shuffled = torch.randperm(self.data.shape[0])
        self.data = self.data[shuffled, :]
        self._idx = 0
        self._epoch_idx += 1

    if num_neg is None:
        batch = self.data[self._idx: self._idx + batch_size].to(device)
        self._idx += batch_size
        return batch

    # Always take at least one positive row: a zero-row slice would return
    # empty tensors and leave the cursor where it is, so a caller iterating
    # until the data is exhausted would never make progress.
    n_pos = max(1, int(batch_size / (2 * num_neg)))
    pos_batch = self.data[self._idx: self._idx + n_pos]
    self._idx += n_pos

    pos_batch = pos_batch.repeat(num_neg, 1).to(device)
    neg_batch = pos_batch.clone()
    n = pos_batch.shape[0]  # (rows actually sliced) * num_neg

    # torch.randint's ``high`` bound is exclusive, so it must be n_ent (not
    # n_ent - 1) for the last entity id to be a possible negative.
    # NOTE(review): a sampled id may coincide with the original value in
    # column 2; no filtering is done here.
    neg_entity = torch.randint(low=0, high=self.n_ent, size=(n,), device=device)
    neg_batch[:, 2] = neg_entity

    label = torch.ones(n, 1, device=device)
    return pos_batch, neg_batch, label