in src/models.py [0:0]
def forward(self, x, score_rhs=True, score_rel=False, score_lhs=False):
    """Score (lhs, rel, rhs) index triples with a translational distance model.

    Each requested score is the negative squared Euclidean distance between a
    query vector and every row of an embedding table. It is computed via the
    expansion ``-||a - b||^2 = 2 a.b - ||a||^2 - ||b||^2``, so all candidates
    are scored with a single matrix product.

    Args:
        x: index tensor. Column 0 indexes ``self.entity`` (lhs), column 1
            indexes ``self.relation`` (rel) and column 2 indexes
            ``self.entity`` (rhs). Presumably shape ``(batch, 3)``; TODO confirm.
        score_rhs: score every entity ``e`` as the rhs: ``-||lhs + rel - e||^2``.
        score_rel: score every relation ``r``: ``-||lhs + r - rhs||^2``.
        score_lhs: score every entity ``e`` as the lhs: ``-||e + rel - rhs||^2``.

    Returns:
        ``None`` if no score is requested. Otherwise ``(scores, factors)``,
        where ``factors`` is the ``(lhs, rel, rhs)`` embeddings. ``scores`` is
        the bare matrix when exactly one score is requested. When several are
        requested, it is a tuple of the requested ones in the fixed order
        ``(rhs, rel, lhs)``. Each matrix has one row per row of ``x`` and one
        column per row of the scored table's ``weight``. This assumes the
        embedding calls return ``(batch, dim)``; TODO confirm.
    """

    def neg_sq_dist(query, weight, coef):
        # -||query -/+ w||^2 for every row w of `weight`, expanded as
        # coef * query.w - ||query||^2 - ||w||^2.
        # coef is 2 when the candidate is subtracted and -2 when it is added.
        return (coef * query @ weight.t()
                - torch.sum(query * query, dim=1).unsqueeze(1)
                - torch.sum(weight * weight, dim=1).unsqueeze(0))

    lhs = self.entity(x[:, 0])
    rel = self.relation(x[:, 1])
    rhs = self.entity(x[:, 2])

    # Collected in the fixed order (rhs, rel, lhs) used by the return tuple.
    scores = []
    if score_rhs:
        # -||(lhs + rel) - e||^2 for every entity e.
        scores.append(neg_sq_dist(lhs + rel, self.entity.weight, 2))
    if score_rel:
        # -||(lhs - rhs) + r||^2 for every relation r.
        scores.append(neg_sq_dist(lhs - rhs, self.relation.weight, -2))
    if score_lhs:
        # -||e + (rel - rhs)||^2 for every entity e.
        scores.append(neg_sq_dist(rel - rhs, self.entity.weight, -2))

    if not scores:
        return None
    factors = (lhs, rel, rhs)
    # A single request yields the bare matrix and several yield a tuple.
    # Previously the (lhs, rel) and (rhs, lhs) combinations fell through and
    # returned None.
    if len(scores) == 1:
        return scores[0], factors
    return tuple(scores), factors