in torchbiggraph/model.py [0:0]
def forward(
    self, edges: EdgeList
) -> "Tuple[Scores, Optional[FloatTensorType]]":
    """Score a batch of positive edges together with their negatives.

    The previous annotation ``-> Scores`` did not match what this method
    returns (a 2-tuple), so it has been corrected. The annotation is written
    as a string so it is never evaluated at definition time.

    Args:
        edges: the positive edges. When the model has dynamic relations
            (``self.num_dynamic_rels > 0``), the edges must NOT have a
            scalar relation type. Otherwise, all edges must share one
            scalar relation type.

    Returns:
        A pair ``(scores, reg)``:

        - ``scores`` is a ``Scores`` built from the lhs/rhs positive scores
          and the lhs/rhs negative scores.
        - ``reg`` is the regularization value reported by
          ``forward_direction_agnostic``. In dynamic-relation mode it is the
          sum of the two directions' values, or ``None`` if either direction
          reports ``None``.

    Raises:
        TypeError: if the relation-type layout of ``edges`` does not match
            the model's relation mode.
        RuntimeError: in non-dynamic mode, if a left-hand side operator is
            set for the relation.
    """
    # NOTE(review): the ``reg`` element type in the annotation is assumed to
    # be a tensor (or None). Confirm against forward_direction_agnostic.
    num_pos = len(edges)

    # Validate the relation-type layout of the batch and choose which entry
    # of self.relations supplies entity types and operators.
    if self.num_dynamic_rels > 0:
        if edges.has_scalar_relation_type():
            raise TypeError("Need relation for each positive pair")
        # With dynamic relations, relation 0 is used for the entity-type and
        # operator lookups below; the per-edge relation types are passed
        # through via edges.get_relation_type().
        relation_idx = 0
    else:
        if not edges.has_scalar_relation_type():
            raise TypeError("All positive pairs must come from the same relation")
        relation_idx = edges.get_relation_type_as_scalar()

    relation = self.relations[relation_idx]
    lhs_module: AbstractEmbedding = self.lhs_embs[self.EMB_PREFIX + relation.lhs]
    rhs_module: AbstractEmbedding = self.rhs_embs[self.EMB_PREFIX + relation.rhs]
    lhs_pos: FloatTensorType = lhs_module(edges.lhs)
    rhs_pos: FloatTensorType = rhs_module(edges.rhs)

    # Pick the negative sampling strategy and the chunk size handed to
    # forward_direction_agnostic. The chunk size is never larger than the
    # number of positives.
    chunk_size: int
    if relation.all_negs:
        chunk_size = num_pos
        negative_sampling_method = Negatives.ALL
    elif self.num_batch_negs == 0:
        chunk_size = min(self.num_uniform_negs, num_pos)
        negative_sampling_method = Negatives.UNIFORM
    else:
        chunk_size = min(self.num_batch_negs, num_pos)
        negative_sampling_method = Negatives.BATCH_UNIFORM

    # Either side can have its negatives switched off independently.
    lhs_negative_sampling_method = (
        Negatives.NONE if self.disable_lhs_negs else negative_sampling_method
    )
    rhs_negative_sampling_method = (
        Negatives.NONE if self.disable_rhs_negs else negative_sampling_method
    )

    if self.num_dynamic_rels == 0:
        # In this case the operator is only applied to the RHS. This means
        # that an edge (u, r, v) is scored with c(u, f_r(v)), whereas the
        # negatives (u', r, v) and (u, r, v') are scored respectively with
        # c(u', f_r(v)) and c(u, f_r(v')). Since r is always the same, each
        # positive and negative right-hand side entity is only passed once
        # through the operator.
        if self.lhs_operators[relation_idx] is not None:
            raise RuntimeError(
                "In non-dynamic relation mode there should "
                "be only a right-hand side operator"
            )
        # Apply operator to right-hand side, sample negatives on both sides
        # unless one side is disabled.
        (
            pos_scores,
            lhs_neg_scores,
            rhs_neg_scores,
            reg,
        ) = self.forward_direction_agnostic(  # noqa
            edges.lhs,
            edges.rhs,
            edges.get_relation_type(),
            relation.lhs,
            relation.rhs,
            None,
            self.rhs_operators[relation_idx],
            lhs_module,
            rhs_module,
            lhs_pos,
            rhs_pos,
            chunk_size,
            lhs_negative_sampling_method,
            rhs_negative_sampling_method,
        )
        # A single pass yields one set of positive scores shared by both sides.
        lhs_pos_scores = rhs_pos_scores = pos_scores
    else:
        # In this case the positive edges may come from different relations.
        # This makes it inefficient to apply the operators to the negatives
        # in the way we do above, because for a negative edge (u, r, v') we
        # would need to compute f_r(v'), with r being different from the one
        # in any positive pair that has v' on the right-hand side, which
        # could lead to v being passed through many different (potentially
        # all) operators. This would result in a combinatorial explosion.
        # So, instead, we duplicate all operators, creating two versions of
        # them, one for each side, and only allow one of them to be applied
        # at any given time. The edge (u, r, v) can thus be scored in two
        # ways, either as c(g_r(u), v) or as c(u, h_r(v)). The negatives
        # (u', r, v) and (u, r, v') are scored respectively as c(u', h_r(v))
        # and c(g_r(u), v'). This way we only need to perform two operator
        # applications for every positive input edge, one for each side.

        # "Forward" edges: apply operator to rhs, sample negatives on lhs.
        lhs_pos_scores, lhs_neg_scores, _, l_reg = self.forward_direction_agnostic(
            edges.lhs,
            edges.rhs,
            edges.get_relation_type(),
            relation.lhs,
            relation.rhs,
            None,
            self.rhs_operators[relation_idx],
            lhs_module,
            rhs_module,
            lhs_pos,
            rhs_pos,
            chunk_size,
            lhs_negative_sampling_method,
            Negatives.NONE,
        )
        # "Reverse" edges: apply operator to lhs, sample negatives on rhs.
        # The sides are swapped in the call so that the "lhs" slot receives
        # the right-hand side entities.
        rhs_pos_scores, rhs_neg_scores, _, r_reg = self.forward_direction_agnostic(
            edges.rhs,
            edges.lhs,
            edges.get_relation_type(),
            relation.rhs,
            relation.lhs,
            None,
            self.lhs_operators[relation_idx],
            rhs_module,
            lhs_module,
            rhs_pos,
            lhs_pos,
            chunk_size,
            rhs_negative_sampling_method,
            Negatives.NONE,
        )
        # Only report a regularizer when both directions produced one.
        if r_reg is None or l_reg is None:
            reg = None
        else:
            reg = l_reg + r_reg

    return (
        Scores(lhs_pos_scores, rhs_pos_scores, lhs_neg_scores, rhs_neg_scores),
        reg,
    )