def invert_torch()

in src/datasets.py [0:0]


def invert_torch(triples: torch.Tensor, n_rel: int, include_type: bool = True) -> torch.Tensor:
    """Rewrite each triple as its reciprocal (used for valid/test evaluation).

    Expected column layout: ``h, r, t, h_neg..., t_neg...`` followed, when
    ``include_type`` is true, by ``h_type, t_type``. The reciprocal:

    * swaps the head and tail columns,
    * adds ``n_rel`` to the relation column,
    * swaps the block of head negatives with the block of tail negatives,
    * swaps the two type columns (only when ``include_type`` is true).

    NOTE: ``triples`` is modified in place, and that same tensor is returned.

    Args:
        triples: 2-D tensor whose columns follow the layout above.
        n_rel: the number of original relations.
        include_type: whether the last two columns hold head/tail types.

    Returns:
        ``triples`` itself, after the in-place rewrite.
    """
    width = triples.shape[1]

    # Describe the whole reordering as one permutation of column indices,
    # produced by applying the pairwise swaps to an index list.
    order = list(range(width))
    order[0], order[2] = order[2], order[0]  # head <-> tail

    fixed_cols = 3  # h, r, t
    if include_type:
        order[-2], order[-1] = order[-1], order[-2]  # h_type <-> t_type
        fixed_cols += 2

    # Whatever remains is split evenly between head and tail negatives.
    num_neg = (width - fixed_cols) // 2
    print(f'Num neg per head/tail {num_neg}')

    if num_neg > 0:
        start, middle, stop = 3, 3 + num_neg, 3 + 2 * num_neg
        order[start:middle], order[middle:stop] = order[middle:stop], order[start:middle]

    # Shift relation ids before permuting. The offset then follows the
    # original relation column to wherever the permutation moves it.
    triples[:, 1] += n_rel
    # Advanced indexing yields a fresh tensor, so copying it back is alias-safe.
    triples.copy_(triples[:, order])
    return triples