def cat()

in torchbiggraph/edgelist.py [0:0]


    def cat(cls, edge_lists: Sequence["EdgeList"]) -> "EdgeList":
        """Concatenate several edge lists into a single one.

        The left- and right-hand sides are joined with ``EntityList.cat``.
        Weights and relation types are joined with ``torch.cat`` after each
        input's tensor has been expanded to length ``len(el)``.

        Args:
            edge_lists: The edge lists to join, in order.

        Returns:
            A new instance built with ``cls(lhs, rhs, rel, weight)``. When
            every input reports a scalar relation type and all of them share
            the same value, ``rel`` is one ``torch.long`` tensor built from
            that value rather than a concatenated tensor. ``weight`` is
            ``None`` when no input has a weight.

        Raises:
            RuntimeError: If some, but not all, inputs have a weight.
        """
        merged_lhs = EntityList.cat([edges.lhs for edges in edge_lists])
        merged_rhs = EntityList.cat([edges.rhs for edges in edge_lists])

        # Weights are all-or-nothing: either every input carries one or none.
        weight_flags = [edges.has_weight() for edges in edge_lists]
        merged_weight = None
        if any(weight_flags):
            if not all(weight_flags):
                raise RuntimeError(
                    "Can't concatenate edgelists with and without weight field."
                )
            weight_parts = [
                edges.weight.expand((len(edges),)) for edges in edge_lists
            ]
            merged_weight = torch.cat(weight_parts)

        # Keep the compact scalar form when all inputs agree on one relation
        # type. The scalar getter is only called once every input has reported
        # that it has a scalar relation type.
        if all(edges.has_scalar_relation_type() for edges in edge_lists):
            distinct_rel_types = {
                edges.get_relation_type_as_scalar() for edges in edge_lists
            }
            if len(distinct_rel_types) == 1:
                shared_rel_type = next(iter(distinct_rel_types))
                scalar_rel = torch.tensor(shared_rel_type, dtype=torch.long)
                return cls(merged_lhs, merged_rhs, scalar_rel, merged_weight)

        # Otherwise expand each relation tensor to length len(edges) and join.
        rel_parts = [edges.rel.expand((len(edges),)) for edges in edge_lists]
        merged_rel = torch.cat(rel_parts)

        return cls(merged_lhs, merged_rhs, merged_rel, merged_weight)