in pyro/ops/contract.py [0:0]
def _partition_terms(ring, terms, dims):
"""
Given a list of terms and a set of contraction dims, partitions the terms
up into sets that must be contracted together. By separating these
components we avoid broadcasting.
This function should be deterministic and free of side effects.
"""
# Construct a bipartite graph between terms and the dims in which they
# are enumerated. This conflates terms and dims (tensors and ints).
neighbors = OrderedDict([(t, []) for t in terms] + [(d, []) for d in sorted(dims)])
for term in terms:
for dim in term._pyro_dims:
if dim in dims:
neighbors[term].append(dim)
neighbors[dim].append(term)
# Partition the bipartite graph into connected components for contraction.
components = []
while neighbors:
v, pending = neighbors.popitem()
component = OrderedDict([(v, None)]) # used as an OrderedSet
for v in pending:
component[v] = None
while pending:
v = pending.pop()
for v in neighbors.pop(v):
if v not in component:
component[v] = None
pending.append(v)
# Split this connected component into tensors and dims.
component_terms = [v for v in component if isinstance(v, torch.Tensor)]
if component_terms:
component_dims = set(v for v in component if not isinstance(v, torch.Tensor))
components.append((component_terms, component_dims))
return components