in src/models.py [0:0]
def forward(self, x, score_rhs=True, score_rel=False, score_lhs=False, normalize_rel=False):
    """Score index triples with a bilinear model using one matrix per relation.

    Each relation embedding is reshaped into a ``rank x rank`` matrix ``R``.
    The triple built from columns 0, 1, 2 of ``x`` is scored as
    ``lhs^T R rhs``. Each flag replaces one position of the triple with every
    row of the corresponding embedding table.

    Args:
        x: integer index tensor. Columns 0 and 2 are looked up in
            ``self.entity``; column 1 is looked up in ``self.relation``
            (assumed shape ``(batch, 3)`` -- TODO confirm against callers).
        score_rhs: if True, score every row of ``self.entity.weight`` as the
            right-hand side, i.e. ``(lhs^T R) @ E^T``.
        score_rel: if True, score every row of ``self.relation.weight`` as the
            relation, i.e. ``vec(lhs rhs^T) @ W_rel^T``.
        score_lhs: if True, score every row of ``self.entity.weight`` as the
            left-hand side, i.e. ``(rhs^T R^T) @ E^T``.
        normalize_rel: currently unused. The relation factor is always
            scaled (see Returns).

    Returns:
        ``None`` if no scoring flag is set. Otherwise ``(scores, factors)``.
        ``scores`` is a single tensor when exactly one flag is set, and
        otherwise a tuple of the requested score tensors in
        (rhs, rel, lhs) order. ``factors`` is
        ``(lhs, R / rank ** (1/3), rhs)``, scaled for the N3 regularizer.
    """
    lhs = self.entity(x[:, 0])
    rel = self.relation(x[:, 1])
    rhs = self.entity(x[:, 2])
    # Each relation embedding is a flattened rank x rank matrix.
    rel = rel.view(-1, self.rank, self.rank)

    # Collect the requested scores in a fixed (rhs, rel, lhs) order so that
    # every combination of flags returns the scores it computed.
    scores = []
    if score_rhs:
        # (batch, 1, rank) @ (batch, rank, rank): lhs^T R for each example.
        lhs_proj = torch.bmm(lhs.view(-1, 1, self.rank), rel).view(-1, self.rank)
        scores.append(lhs_proj @ self.entity.weight.t())
    if score_rel:
        # Flattened outer product lhs rhs^T dotted with each flattened
        # relation row equals lhs^T R_k rhs.
        lr_proj = torch.bmm(lhs.view(-1, self.rank, 1), rhs.view(-1, 1, self.rank))
        lr_proj = lr_proj.view(-1, self.rank * self.rank)
        scores.append(lr_proj @ self.relation.weight.t())
    if score_lhs:
        # rhs^T R^T, dotted with every candidate entity, gives e^T R rhs.
        rhs_proj = torch.bmm(rhs.view(-1, 1, self.rank), rel.transpose(1, 2))
        rhs_proj = rhs_proj.view(-1, self.rank)
        scores.append(rhs_proj @ self.entity.weight.t())

    if not scores:
        return None

    # Scale the relation factor for the N3 regularizer. normalize_rel is
    # ignored, so this scaling is unconditional.
    factors = (lhs, rel / (self.rank ** (1 / 3.0)), rhs)
    if len(scores) == 1:
        return scores[0], factors
    return tuple(scores), factors