in src/models.py [0:0]
def forward(self, x, score_rhs=True, score_rel=False, score_lhs=False, normalize_rel=False):
    """Score triples against all candidates for the requested positions.

    Each row of ``x`` holds index columns: ``x[:, 0]`` is the lhs entity,
    ``x[:, 1]`` is the relation, and ``x[:, 2]`` is the rhs entity. They are
    looked up via ``self.entity`` / ``self.relation``. The embeddings are
    contracted with ``self.core``. Based on the contractions below, the core
    presumably has shape (rank_e, rank_r, rank_e) and is indexed
    [lhs, rel, rhs] — TODO confirm.

    Args:
        x: Integer index tensor with at least 3 columns (lhs, rel, rhs).
        score_rhs: Score every entity as the rhs, given (lhs, rel).
        score_rel: Score every relation, given (lhs, rhs).
        score_lhs: Score every entity as the lhs, given (rel, rhs).
        normalize_rel: Currently unused; kept for interface compatibility.

    Returns:
        ``None`` if no score flag is set. Otherwise ``(scores, factors)``:
        ``scores`` is a single tensor when exactly one flag is set, or else a
        tuple ordered (rhs_scores, rel_scores, lhs_scores) restricted to the
        requested ones. ``factors`` is ``(lhs, scaled_rel, rhs)``.
        Previously the (lhs, rel) and (rhs, lhs) combinations silently
        returned ``None``.
    """
    lhs = self.entity(x[:, 0])
    rel = self.relation(x[:, 1])
    rhs = self.entity(x[:, 2])

    scores = []  # kept in (rhs, rel, lhs) order

    if score_rhs or score_rel:
        # Contract the core with lhs once; both the rhs and rel paths use it.
        # Shape: (b, rank_r, rank_e).
        lhs_core = torch.matmul(self.core.transpose(0, 2), lhs.transpose(0, 1)).transpose(0, 2)

    if score_rhs:
        rel_row = rel.view(-1, 1, self.rank_r)  # (b, 1, rank_r)
        rhs_query = torch.bmm(rel_row, self.dropout(lhs_core)).view(-1, self.rank_e)
        scores.append(rhs_query @ self.entity.weight.t())

    if score_rel:
        rhs_col = rhs.view(-1, self.rank_e, 1)  # (b, rank_e, 1)
        rel_query = torch.bmm(self.dropout(lhs_core), rhs_col).view(-1, self.rank_r)
        scores.append(rel_query @ self.relation.weight.t())

    if score_lhs:
        # Contract the core with rhs. Shape: (b, rank_r, rank_e).
        rhs_core = torch.matmul(self.core, rhs.transpose(0, 1)).transpose(0, 2)
        rel_row = rel.view(-1, 1, self.rank_r)
        lhs_query = torch.bmm(rel_row, self.dropout(rhs_core)).view(-1, self.rank_e)
        scores.append(lhs_query @ self.entity.weight.t())

    # The rank of the relation is smaller than that of the entity, so the
    # relation factor is rescaled by (rank_e / rank_r) ** (1/3).
    factors = (lhs,
               rel * ((self.rank_e * 1.0 / self.rank_r) ** (1 / 3.0)),
               rhs)

    if not scores:
        return None
    if len(scores) == 1:
        return scores[0], factors
    return tuple(scores), factors