in src/models.py [0:0]
def forward(self, x, score_rhs=True, score_rel=False, score_lhs=False, normalize_rel=False):
    """Score each row of ``x`` against the full rhs / rel / lhs embedding tables.

    Columns 0, 1 and 2 of ``x`` are looked up in ``self.lhs``, ``self.rel``
    and ``self.rhs`` respectively (presumably (lhs, rel, rhs) index
    triples -- confirm against the caller). For every requested target the
    element-wise product of the two *other* embeddings is multiplied with
    the transposed weight table of the target embedding.

    Args:
        x: 2-D index tensor with at least three columns.
        score_rhs: if truthy, compute ``(lhs * rel) @ self.rhs.weight.t()``.
        score_rel: if truthy, compute ``(lhs * rhs) @ self.rel.weight.t()``.
        score_lhs: if truthy, compute ``(rhs * rel) @ self.lhs.weight.t()``.
        normalize_rel: accepted for interface compatibility; not used by
            this method.

    Returns:
        ``None`` when no ``score_*`` flag is set. Otherwise a pair
        ``(scores, factors)`` where ``factors`` is ``self.get_factor(x)``
        and ``scores`` is either the single requested score tensor or,
        when several are requested, a tuple of them in the fixed order
        (rhs, rel, lhs) with the unrequested ones omitted.
    """
    lhs = self.lhs(x[:, 0])
    rel = self.rel(x[:, 1])
    rhs = self.rhs(x[:, 2])

    # Collect the requested scores in the canonical (rhs, rel, lhs) order so
    # that every flag combination -- including lhs+rel and rhs+lhs, which
    # previously fell through and returned None -- yields a result.
    scores = []
    if score_rhs:
        scores.append((lhs * rel) @ self.rhs.weight.t())
    if score_rel:
        scores.append((lhs * rhs) @ self.rel.weight.t())
    if score_lhs:
        scores.append((rhs * rel) @ self.lhs.weight.t())

    # Always evaluated, even when nothing is scored, to keep the original
    # call sequence on self.get_factor.
    factors = self.get_factor(x)

    if not scores:
        # Backward-compatible: no target requested -> bare None, not a pair.
        return None
    if len(scores) == 1:
        # A single target is returned unwrapped, not as a 1-tuple.
        return scores[0], factors
    return tuple(scores), factors