in src/datasets.py [0:0]
def invert_torch(triples: torch.Tensor, n_rel: int, include_type: bool = True):
    """Rewrite every row of ``triples`` as its reciprocal triple, in place.

    Used in valid/test. For each row:

    * the head and tail columns are swapped;
    * the relation column is shifted by ``n_rel``, giving the reciprocal
      relation id;
    * the head-side and tail-side negative-sample blocks are swapped;
    * if ``include_type``, the head/tail type columns are swapped.

    Args:
        triples: 2-D tensor with columns laid out as
            ``h, r, t, h_neg * k, t_neg * k[, h_type, t_type]``. Here ``k``
            is the number of negatives per side and may be 0. The tensor is
            modified in place.
        n_rel: the number of original relations.
        include_type: whether the last two columns are ``h_type, t_type``.

    Returns:
        The same tensor object ``triples``, now holding the reciprocal triples.

    Raises:
        ValueError: if ``triples`` is not 2-D, or if its column count does not
            fit the layout above. That covers too few columns and an odd
            number of negative columns. Without this check such input would be
            silently corrupted; for example, the type swap would hit the
            ``r``/``t`` columns.
    """
    if triples.dim() != 2:
        raise ValueError(
            f'invert_torch expects a 2-D tensor, got shape {tuple(triples.shape)}')

    n_cols = triples.shape[1]
    n_fixed = 5 if include_type else 3  # h, r, t (+ h_type, t_type)
    num_neg, odd = divmod(n_cols - n_fixed, 2)
    if n_cols < n_fixed or odd:
        layout = 'h, r, t, h_neg*k, t_neg*k'
        if include_type:
            layout += ', h_type, t_type'
        raise ValueError(
            f'invert_torch: {n_cols} columns do not match layout ({layout})')
    print('Num neg per head/tail {}'.format(num_neg))

    # Build one column permutation that performs every side swap at once.
    perm = list(range(n_cols))
    perm[0], perm[2] = 2, 0  # head <-> tail
    if include_type:
        perm[-2], perm[-1] = n_cols - 1, n_cols - 2  # h_type <-> t_type
    h_neg_cols = list(range(3, 3 + num_neg))
    t_neg_cols = list(range(3 + num_neg, 3 + 2 * num_neg))
    perm[3:3 + 2 * num_neg] = t_neg_cols + h_neg_cols  # h_neg <-> t_neg

    # Indexing with a list materialises a copy, so writing it back into the
    # same storage is safe even though source and destination overlap. This
    # keeps the original in-place semantics for callers.
    triples[:] = triples[:, perm]
    triples[:, 1] += n_rel  # column 1 is not permuted; shift to reciprocal id
    return triples