def _contract_component()

in pyro/ops/contract.py [0:0]


def _contract_component(ring, tensor_tree, sum_dims, target_dims):
    """
    Collapse a tree of tensors into a single tensor via message passing,
    summing out the dims in ``sum_dims - target_dims`` along the way. The
    surviving tensor lives in the minimum plate context of the tree.

    This function should be deterministic.
    This function has side-effects: ``tensor_tree`` is consumed in-place and
    is empty once the function returns.

    :param pyro.ops.rings.Ring ring: an algebraic ring defining tensor
        operations.
    :param OrderedDict tensor_tree: a dictionary mapping ordinals to lists of
        tensors. An ordinal is a frozenset of ``CondIndepStack`` frames.
    :param set sum_dims: the complete set of sum-contraction dimensions
        (indexed from the right). This is needed to distinguish
        sum-contraction dimensions from product-contraction dimensions.
    :param set target_dims: a subset of ``sum_dims`` that should be preserved
        in the result.
    :return: a pair ``(ordinal, tensor)``
    :rtype: tuple of frozenset and torch.Tensor
    """
    # Attach each sum dim to the intersection of every ordinal whose tensors
    # mention that dim.
    ordinal_of_dim = {}
    for context, tensors in tensor_tree.items():
        for tensor in tensors:
            for dim in sum_dims.intersection(tensor._pyro_dims):
                seen = ordinal_of_dim.get(dim)
                ordinal_of_dim[dim] = context if seen is None else seen & context

    # Invert the mapping: ordinal -> set of sum dims owned by that ordinal.
    dims_tree = {}
    for dim, context in ordinal_of_dim.items():
        dims_tree.setdefault(context, set()).add(dim)

    # State for the optional localizing pass performed after the main loop.
    deferred_terms = []
    kept_dims = target_dims.copy()
    deferred_ordinal = frozenset()
    root = frozenset.intersection(*tensor_tree)

    # Pass messages from the leaves towards the root until no sum dims remain.
    while any(dims_tree.values()):
        # Deterministic leaf choice: the first ordinal with the most frames.
        leaf = max(tensor_tree, key=len)
        leaf_terms = tensor_tree.pop(leaf)
        leaf_dims = dims_tree.pop(leaf, set())
        at_root = leaf == root

        # Handle each connected component of the leaf's terms separately.
        for group, group_dims in _partition_terms(ring, leaf_terms, leaf_dims):

            # Sum out the component's dims, except those that must be kept.
            message = ring.sumproduct(group, group_dims - kept_dims)

            # At the root there are no extra plate frames to contract away.
            if at_root:
                tensor_tree.setdefault(leaf, []).append(message)
                continue

            # The message goes to the union of all ordinals that still own
            # one of its outstanding sum dims.
            pending = sum_dims.intersection(message._pyro_dims)
            recipients = [ctx for ctx, ctx_dims in dims_tree.items() if ctx_dims & pending]
            parent = frozenset.union(*recipients)
            _check_tree_structure(parent, leaf)
            extra_frames = leaf - parent

            # Product-contract the frames that the parent does not share.
            localized = group_dims & kept_dims
            if not localized:
                message = ring.product(message, extra_frames)
            else:
                # Preserved dims are present: split into a global message and
                # a local part that is set aside for the localizing pass.
                message, local_part = ring.global_local(message, localized, extra_frames)
                deferred_terms.append(local_part)
                kept_dims |= sum_dims.intersection(local_part._pyro_dims)
                deferred_ordinal |= leaf
            tensor_tree.setdefault(parent, []).append(message)

    # Exactly one tensor should remain, and it should sit at the root ordinal.
    assert len(tensor_tree) == 1
    ordinal, leftover = tensor_tree.popitem()
    (term,) = leftover
    assert ordinal == root

    # Without deferred local terms the result is already final.
    if not deferred_terms:
        return ordinal, term

    # Localizing pass: fold the deferred local terms back into the root term,
    # summing out everything that is not a target dim.
    assert target_dims
    deferred_terms.append(term)
    term = ring.sumproduct(deferred_terms, kept_dims - target_dims)
    ordinal |= deferred_ordinal
    return ordinal, term