in pyro/ops/contract.py [0:0]
def _contract_component(ring, tensor_tree, sum_dims, target_dims):
    """
    Collapse a tree of tensors into a single tensor by message passing,
    contracting out the dims in ``sum_dims - target_dims`` along the way.
    The result lives in the minimum plate context of the tree.

    This function should be deterministic.
    This function has side-effects: ``tensor_tree`` is consumed in-place
    (entries are popped and re-inserted until it is finally emptied).

    :param pyro.ops.rings.Ring ring: an algebraic ring defining tensor
        operations (``sumproduct``, ``product`` and ``global_local`` are
        used here).
    :param OrderedDict tensor_tree: a dictionary mapping ordinals to lists of
        tensors. An ordinal is a frozenset of ``CondIndepStack`` frames.
        Each tensor is read through its ``_pyro_dims`` attribute.
    :param set sum_dims: the complete set of sum-contraction dimensions.
        This is needed to tell sum-contraction dimensions apart from
        product-contraction dimensions.
    :param set target_dims: a subset of ``sum_dims`` that should be preserved
        in the result. This set is copied, not mutated.
    :return: a pair ``(ordinal, tensor)``
    :rtype: tuple of frozenset and torch.Tensor
    """
    # Map every sum dim to the intersection of all ordinals it appears in.
    ordinal_of_dim = {}
    for node, tensors in tensor_tree.items():
        for tensor in tensors:
            for dim in sum_dims.intersection(tensor._pyro_dims):
                if dim in ordinal_of_dim:
                    ordinal_of_dim[dim] = ordinal_of_dim[dim] & node
                else:
                    ordinal_of_dim[dim] = node

    # Invert that mapping: ordinal -> set of sum dims owned by that ordinal.
    dims_tree = defaultdict(set)
    for dim, node in ordinal_of_dim.items():
        dims_tree[node].add(dim)

    # State accumulated while walking from the leaves towards the root.
    deferred_terms = []                # local parts split off by global_local
    kept_dims = target_dims.copy()     # dims that must not be summed out yet
    deferred_ordinal = frozenset()     # union of leaves that deferred a term
    root = frozenset.intersection(*tensor_tree)

    # Keep eliminating leaves for as long as any sum dim is still pending.
    while any(dims_tree.values()):
        # Deterministic (if arbitrary) choice: the first largest ordinal.
        leaf = max(tensor_tree, key=len)
        leaf_tensors = tensor_tree.pop(leaf)
        leaf_dims = dims_tree.pop(leaf, set())

        # Handle each connected component of the leaf independently.
        for group, group_dims in _partition_terms(ring, leaf_tensors, leaf_dims):
            # Sum out this component's dims, except those being kept.
            message = ring.sumproduct(group, group_dims - kept_dims)

            parent = leaf
            if leaf != root:
                # The parent is the union of every remaining ordinal that
                # owns a sum dim still present in the message.
                pending = sum_dims.intersection(message._pyro_dims)
                owners = [node for node, owned in dims_tree.items()
                          if owned & pending]
                parent = frozenset.union(*owners)
                _check_tree_structure(parent, leaf)

                # Plate frames that exist at the leaf but not at the parent
                # are removed via product contraction.
                extra_frames = leaf - parent
                kept_here = group_dims & kept_dims
                if kept_here:
                    # Kept dims are present: split into a part passed up
                    # the tree and a local part combined after the loop.
                    message, local_part = ring.global_local(
                        message, kept_here, extra_frames)
                    deferred_terms.append(local_part)
                    kept_dims |= sum_dims.intersection(local_part._pyro_dims)
                    deferred_ordinal |= leaf
                else:
                    message = ring.product(message, extra_frames)

            tensor_tree.setdefault(parent, []).append(message)

    # Exactly one tensor must remain, and it must sit at the root ordinal.
    assert len(tensor_tree) == 1
    final_ordinal, remaining = tensor_tree.popitem()
    (result,) = remaining
    assert final_ordinal == root

    # Without deferred local terms there is no localizing pass to run.
    if not deferred_terms:
        return final_ordinal, result

    # Localizing pass: fold the deferred terms back in and sum out the dims
    # that were only kept temporarily.
    assert target_dims
    deferred_terms.append(result)
    result = ring.sumproduct(deferred_terms, kept_dims - target_dims)
    return final_ordinal | deferred_ordinal, result