def forward()

in src/models.py [0:0]


    def forward(self, x, score_rhs=True, score_rel=False, score_lhs=False, normalize_rel=False):
        """Score a batch of (lhs, rel, rhs) index triples against every candidate.

        Columns 0 and 2 of ``x`` are looked up in ``self.entity`` and column 1
        in ``self.relation``. For each requested slot, the other two embeddings
        are contracted with ``self.core`` and the result is multiplied with the
        full embedding table of the slot being predicted.

        Args:
            x: 2-D index tensor whose columns 0, 1, 2 are read as lhs, relation
                and rhs ids.
            score_rhs: If truthy, score every entity as the rhs.
            score_rel: If truthy, score every relation.
            score_lhs: If truthy, score every entity as the lhs.
            normalize_rel: Accepted for interface compatibility; not used by
                this method.

        Returns:
            ``None`` when no score flag is set. Otherwise ``(scores, factors)``
            where ``factors`` is ``(lhs, scaled rel, rhs)`` and ``scores`` is a
            single tensor when exactly one flag is set, or a tuple of tensors
            in (rhs, rel, lhs) order restricted to the requested slots when
            several are set.

        Note:
            Assumes ``self.core`` is shaped (rank_e, rank_r, rank_e); this is
            implied by the contractions below but defined outside this method.
        """
        lhs = self.entity(x[:, 0])
        rel = self.relation(x[:, 1])
        rhs = self.entity(x[:, 2])

        # Requested score matrices, always appended in (rhs, rel, lhs) order so
        # that every flag combination yields a consistently ordered result.
        scores = []

        if score_rhs:
            # Contract the core with lhs along the core's first axis.
            core_lhs = torch.matmul(
                self.core.transpose(0, 2), lhs.transpose(0, 1)
            ).transpose(0, 2)  # b, rank_r, rank_e
            rel_row = rel.view(-1, 1, self.rank_r)
            rhs_query = torch.bmm(rel_row, self.dropout(core_lhs)).view(-1, self.rank_e)
            scores.append(rhs_query @ self.entity.weight.t())

        if score_rel:
            # Recomputed (not shared with the rhs branch) so each branch applies
            # dropout to its own fresh tensor.
            core_lhs = torch.matmul(
                self.core.transpose(0, 2), lhs.transpose(0, 1)
            ).transpose(0, 2)  # b, rank_r, rank_e
            rhs_col = rhs.view(-1, self.rank_e, 1)
            rel_query = torch.bmm(self.dropout(core_lhs), rhs_col).view(-1, self.rank_r)  # b, rank_r
            scores.append(rel_query @ self.relation.weight.t())

        if score_lhs:
            # Contract the core with rhs along the core's last axis.
            core_rhs = torch.matmul(
                self.core, rhs.transpose(0, 1)
            ).transpose(0, 2)  # b, rank_r, rank_e
            rel_row = rel.view(-1, 1, self.rank_r)
            lhs_query = torch.bmm(rel_row, self.dropout(core_rhs)).view(-1, self.rank_e)
            scores.append(lhs_query @ self.entity.weight.t())

        # The rank of relation is smaller than that of entity, so the relation
        # factor is rescaled before being handed back.
        rel_scale = (self.rank_e * 1.0 / self.rank_r) ** (1 / 3.0)
        factors = (lhs, rel * rel_scale, rhs)

        if not scores:
            return None
        if len(scores) == 1:
            return scores[0], factors
        # Two or three slots requested: previously the (lhs, rel) and
        # (rhs, lhs) combinations silently fell through and returned None.
        return tuple(scores), factors