in torchbiggraph/edgelist.py [0:0]
def cat(cls, edge_lists: Sequence["EdgeList"]) -> "EdgeList":
    """Concatenate several edge lists, in order, into a single edge list.

    The lhs and rhs entities are joined with ``EntityList.cat``.

    Weights must be present on every input or on none of them. When present,
    each input's weight is expanded to one value per edge before being
    concatenated.

    If every input has a scalar relation type and they all share the same
    value, the result keeps that single relation type as a 0-dim ``long``
    tensor. Otherwise each input's relation type is expanded to one entry per
    edge and the expanded tensors are concatenated.

    Args:
        edge_lists: the edge lists to join, in output order.

    Returns:
        A new instance of ``cls`` holding the concatenated edges.

    Raises:
        RuntimeError: if some, but not all, inputs carry a weight.
    """
    lhs = EntityList.cat([edges.lhs for edges in edge_lists])
    rhs = EntityList.cat([edges.rhs for edges in edge_lists])

    weight = None
    if any(edges.has_weight() for edges in edge_lists):
        # A mix of weighted and unweighted inputs has no sensible merged
        # weight, so refuse rather than invent defaults.
        if not all(edges.has_weight() for edges in edge_lists):
            raise RuntimeError(
                "Can't concatenate edgelists with and without weight field."
            )
        # expand() broadcasts a weight that is (presumably) stored as a single
        # value to one entry per edge. A weight that already has len(edges)
        # entries is left unchanged.
        weight = torch.cat(
            [edges.weight.expand((len(edges),)) for edges in edge_lists]
        )

    if all(edges.has_scalar_relation_type() for edges in edge_lists):
        distinct_types = {
            edges.get_relation_type_as_scalar() for edges in edge_lists
        }
        if len(distinct_types) == 1:
            # All inputs share one relation type: keep the compact scalar form
            # instead of materialising a per-edge tensor.
            only_type = next(iter(distinct_types))
            return cls(lhs, rhs, torch.tensor(only_type, dtype=torch.long), weight)

    # Mixed or per-edge relation types: give every edge its own entry.
    rel = torch.cat([edges.rel.expand((len(edges),)) for edges in edge_lists])
    return cls(lhs, rhs, rel, weight)