def forward()

in src/models.py [0:0]


    def forward(self, x, score_rhs=True, score_rel=False, score_lhs=False):
        """Score triples against all candidates using the negative squared TransE distance.

        A triple (lhs, rel, rhs) is scored as ``-||lhs + rel - rhs||_2^2``.
        For each query row, every candidate of the requested slot is scored
        at once. The square is expanded as ``2<a, b> - ||a||^2 - ||b||^2``,
        so the candidate dimension becomes a single matrix product.

        Args:
            x: integer index tensor whose columns 0, 1 and 2 hold the
                lhs-entity, relation and rhs-entity ids (presumably shape
                (batch, 3) -- confirm against callers).
            score_rhs: score every entity as the rhs of ``(lhs, rel, ?)``.
            score_rel: score every relation as the middle of ``(lhs, ?, rhs)``.
            score_lhs: score every entity as the lhs of ``(?, rel, rhs)``.

        Returns:
            ``None`` if no score was requested. Otherwise ``(scores, factors)``:
            - ``factors = (lhs, rel, rhs)`` are the looked-up embeddings.
            - ``scores`` is the single requested score matrix, or, when
              several are requested, a tuple of them in the fixed order
              (rhs, rel, lhs).
            Entity score matrices have one column per row of
            ``self.entity.weight``. Relation score matrices have one column
            per row of ``self.relation.weight``.
        """
        lhs = self.entity(x[:, 0])
        rel = self.relation(x[:, 1])
        rhs = self.entity(x[:, 2])

        entity_weight = self.entity.weight
        relation_weight = self.relation.weight

        if score_rhs or score_lhs:
            # ||e||^2 for every candidate entity, shared by rhs and lhs scoring
            # so it is only computed once.
            entity_sq_norms = torch.sum(entity_weight * entity_weight, dim=1).unsqueeze(0)

        # Requested score matrices, collected in the fixed order (rhs, rel, lhs).
        scores = []

        if score_rhs:
            # -||(lhs + rel) - e||^2 = 2 <lhs_proj, e> - ||lhs_proj||^2 - ||e||^2
            lhs_proj = lhs + rel
            scores.append(2 * lhs_proj @ entity_weight.t()
                          - torch.sum(lhs_proj * lhs_proj, dim=1).unsqueeze(1)
                          - entity_sq_norms)

        if score_rel:
            # -||(lhs - rhs) + r||^2 = -2 <lr_proj, r> - ||lr_proj||^2 - ||r||^2
            lr_proj = lhs - rhs
            scores.append(-2 * lr_proj @ relation_weight.t()
                          - torch.sum(lr_proj * lr_proj, dim=1).unsqueeze(1)
                          - torch.sum(relation_weight * relation_weight, dim=1).unsqueeze(0))

        if score_lhs:
            # -||e + (rel - rhs)||^2 = -2 <rhs_proj, e> - ||rhs_proj||^2 - ||e||^2
            rhs_proj = rel - rhs
            scores.append(-2 * rhs_proj @ entity_weight.t()
                          - torch.sum(rhs_proj * rhs_proj, dim=1).unsqueeze(1)
                          - entity_sq_norms)

        factors = (lhs, rel, rhs)
        if not scores:
            return None
        if len(scores) == 1:
            return scores[0], factors
        # Previously the (lhs, rel) and (rhs, lhs) combinations fell through and
        # returned None; every combination now returns its scores.
        return tuple(scores), factors