In torchbiggraph/util.py:
def __init__(self, config: ConfigSchema) -> None:
    # For each side of the edges, work out how many partitions there are
    # and which entity types are partitioned vs. unpartitioned.
    (
        self.nparts_lhs,
        self.lhs_unpartitioned_types,
        self.lhs_partitioned_types,
    ) = get_partitioned_types(config, Side.LHS)  # noqa
    (
        self.nparts_rhs,
        self.rhs_unpartitioned_types,
        self.rhs_partitioned_types,
    ) = get_partitioned_types(config, Side.RHS)  # noqa
    if self.nparts_lhs == 1 and self.nparts_rhs == 1:
        assert (
            config.num_machines == 1
        ), "Cannot run distributed training with a single partition."
        # With a single partition on both sides, treat every type as
        # "partitioned" (into that one partition) and leave the
        # unpartitioned sets empty, so downstream code takes one path.
        self.lhs_partitioned_types = self.lhs_unpartitioned_types
        self.rhs_partitioned_types = self.rhs_unpartitioned_types
        self.lhs_unpartitioned_types = set()
        self.rhs_unpartitioned_types = set()
    # Embedding tables, keyed by entity type for unpartitioned types and
    # by (entity type, partition) for partitioned ones; populated later.
    self.unpartitioned_embeddings: Dict[EntityName, torch.nn.Parameter] = {}
    self.partitioned_embeddings: Dict[
        Tuple[EntityName, Partition], torch.nn.Parameter
    ] = {}
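
For context, here is a minimal, self-contained sketch of the same bookkeeping. ToyEntity, ToyConfig, and toy_get_partitioned_types are hypothetical stand-ins for PBG's ConfigSchema and get_partitioned_types, reduced to the fields this constructor actually touches; the sketch only illustrates the single-partition fallback above, not the real API.

# Hypothetical, reduced stand-ins for ConfigSchema / get_partitioned_types,
# kept only to demonstrate the single-partition fallback above.
from dataclasses import dataclass
from typing import Dict, Set, Tuple

import torch


@dataclass
class ToyEntity:
    num_partitions: int


@dataclass
class ToyConfig:
    entities: Dict[str, ToyEntity]
    num_machines: int = 1


def toy_get_partitioned_types(config: ToyConfig) -> Tuple[int, Set[str], Set[str]]:
    # Returns (num_partitions, unpartitioned_types, partitioned_types):
    # a type counts as "unpartitioned" if it lives in a single partition.
    nparts = max(e.num_partitions for e in config.entities.values())
    unpartitioned = {n for n, e in config.entities.items() if e.num_partitions == 1}
    partitioned = {n for n, e in config.entities.items() if e.num_partitions > 1}
    return nparts, unpartitioned, partitioned


class ToyEmbeddingHolder:
    def __init__(self, config: ToyConfig) -> None:
        (
            self.nparts,
            self.unpartitioned_types,
            self.partitioned_types,
        ) = toy_get_partitioned_types(config)
        if self.nparts == 1:
            # Same move as the real constructor: with one partition
            # everywhere, route every type through the partitioned path.
            assert config.num_machines == 1, "distributed needs partitions"
            self.partitioned_types = self.unpartitioned_types
            self.unpartitioned_types = set()
        self.unpartitioned_embeddings: Dict[str, torch.nn.Parameter] = {}
        self.partitioned_embeddings: Dict[Tuple[str, int], torch.nn.Parameter] = {}


holder = ToyEmbeddingHolder(
    ToyConfig(entities={"user": ToyEntity(1), "item": ToyEntity(1)})
)
print(holder.partitioned_types)    # e.g. {'user', 'item'}: moved by the fallback
print(holder.unpartitioned_types)  # set()

Running the sketch shows both entity types being rerouted into the partitioned set, which is exactly what the fallback branch in the real constructor does.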