def einsum()

in pyro/ops/contract.py [0:0]


def einsum(equation, *operands, **kwargs):
    """
    Generalized plated sum-product algorithm via tensor variable elimination.

    This generalizes :func:`~pyro.ops.einsum.contract` in two ways:

    1.  Multiple outputs are allowed, and intermediate results can be shared.
    2.  Inputs and outputs can be plated along symbols given in ``plates``;
        reductions along ``plates`` are product reductions.

    The best way to understand this function is to try the examples below,
    which show how :func:`einsum` calls can be implemented as multiple calls
    to :func:`~pyro.ops.einsum.contract` (which is generally more expensive).

    Note that :func:`einsum` always returns a tuple with one entry per output,
    so single-output calls below are unpacked as ``z, = einsum(...)``. Also,
    any call with a nonempty output requires ``modulo_total=True`` (otherwise
    :py:class:`NotImplementedError` is raised); each result plate of such a
    call agrees with the corresponding :func:`~pyro.ops.einsum.contract` call
    only up to an arbitrary scale factor.

    To illustrate multiple outputs, note that the following are equivalent::

        z1, z2, z3 = einsum('ab,bc->a,b,c', x, y,
                            modulo_total=True)  # multiple outputs

        z1 = contract('ab,bc->a', x, y)
        z2 = contract('ab,bc->b', x, y)
        z3 = contract('ab,bc->c', x, y)

    To illustrate plated inputs, note that the following are equivalent::

        assert len(x) == 3 and len(y) == 3
        z, = einsum('ab,ai,bi->b', w, x, y, plates='i', modulo_total=True)

        z = contract('ab,a,a,a,b,b,b->b', w, *x, *y)

    When a sum dimension `a` always appears with a plate dimension `i`,
    then `a` corresponds to a distinct symbol for each slice of `i`. Thus
    the following are equivalent::

        assert len(x) == 3 and len(y) == 3
        z, = einsum('ai,ai->', x, y, plates='i')

        z = contract('a,b,c,a,b,c->', *x, *y)

    When such a sum dimension appears in the output, it must be
    accompanied by all of its plate dimensions, e.g. the following are
    equivalent::

        assert len(x) == 3 and len(y) == 3
        z, = einsum('abi,abi->bi', x, y, plates='i', modulo_total=True)

        z0 = contract('ab,ac,ad,ab,ac,ad->b', *x, *y)
        z1 = contract('ab,ac,ad,ab,ac,ad->c', *x, *y)
        z2 = contract('ab,ac,ad,ab,ac,ad->d', *x, *y)
        z = torch.stack([z0, z1, z2])

    Note that each plate slice through the output is multilinear in all plate
    slices through all inputs, thus e.g. batch matrix multiply would be
    implemented *without* ``plates``, so the following are all equivalent::

        xy, = einsum('abc,acd->abd', x, y, plates='', modulo_total=True)
        xy = torch.stack([xa.mm(ya) for xa, ya in zip(x, y)])
        xy = torch.bmm(x, y)

    Among all valid equations, some computations are polynomial in the sizes of
    the input tensors and other computations are exponential in the sizes of
    the input tensors. This function raises :py:class:`NotImplementedError`
    whenever the computation is exponential.

    Side effect: each operand tensor is annotated in place with a
    ``_pyro_dims`` attribute holding its dims string (duplicated operands are
    first replaced by distinct views so their annotations do not collide).

    :param str equation: An einsum equation, optionally with multiple outputs.
    :param torch.Tensor operands: A collection of tensors.
    :param str plates: An optional string of plate symbols.
    :param str backend: An optional einsum backend, defaults to 'torch'.
    :param dict cache: An optional :func:`~opt_einsum.shared_intermediates`
        cache.
    :param bool modulo_total: Optionally allow einsum to arbitrarily scale
        each result plate, which can significantly reduce computation. This is
        safe to set whenever each result plate denotes a nonnormalized
        probability distribution whose total is not of interest. Required
        whenever any output is nonempty.
    :return: a tuple of tensors of requested shape, one entry per output.
    :rtype: tuple
    :raises ValueError: if tensor sizes mismatch or an output requests a
        plated dim without that dim's plates.
    :raises NotImplementedError: if the backend is unsupported, if the
        equation uses ellipsis notation, if an output is nonempty but
        ``modulo_total`` is False, or if contraction would have cost
        exponential in the size of any input tensor.
    """
    # Extract kwargs.
    cache = kwargs.pop('cache', None)
    plates = kwargs.pop('plates', '')
    backend = kwargs.pop('backend', 'torch')
    modulo_total = kwargs.pop('modulo_total', False)
    try:
        Ring = BACKEND_TO_RING[backend]
    except KeyError:
        raise NotImplementedError('\n'.join(
            ['Only the following pyro backends are currently implemented:'] +
            list(BACKEND_TO_RING)))

    # Parse generalized einsum equation.
    if '.' in equation:
        raise NotImplementedError('einsum does not yet support ellipsis notation')
    inputs, outputs = equation.split('->')
    inputs = inputs.split(',')
    outputs = outputs.split(',')
    assert len(inputs) == len(operands)
    assert all(isinstance(x, torch.Tensor) for x in operands)
    # A nonempty output is only supported up to a per-plate scale factor.
    if not modulo_total and any(outputs):
        raise NotImplementedError('Try setting modulo_total=True and ensuring that your use case '
                                  'allows an arbitrary scale factor on each result plate.')
    # Each operand is tagged with a ``_pyro_dims`` attribute below, so a tensor
    # object passed more than once would have its tag overwritten; ``x[...]``
    # gives every operand its own view object to carry its own tag.
    if len(operands) != len(set(operands)):
        operands = [x[...] for x in operands]  # ensure tensors are unique

    # Check sizes, recording the size of every symbol that appears in inputs.
    # NOTE(review): ignore_jit_warnings presumably silences tracer warnings
    # triggered by converting shapes to Python ints — confirm.
    with ignore_jit_warnings():
        dim_to_size = {}
        for dims, term in zip(inputs, operands):
            for dim, size in zip(dims, map(int, term.shape)):
                old = dim_to_size.setdefault(dim, size)
                if old != size:
                    raise ValueError(u"Dimension size mismatch at dim '{}': {} vs {}"
                                     .format(dim, size, old))

    # Construct a tensor tree shared by all outputs, grouping terms by the set
    # of plate symbols (their "ordinal") that each term is plated along.
    tensor_tree = OrderedDict()
    plates = frozenset(plates)
    for dims, term in zip(inputs, operands):
        assert len(dims) == term.dim()
        term._pyro_dims = dims
        ordinal = plates.intersection(dims)
        tensor_tree.setdefault(ordinal, []).append(term)

    # Compute outputs, sharing intermediate computations via the cache.
    results = []
    with shared_intermediates(cache) as cache:
        ring = Ring(cache, dim_to_size=dim_to_size)
        for output in outputs:
            # Every non-plate symbol mentioned anywhere; those not kept as
            # target_dims are presumably summed out by contract_to_tensor.
            sum_dims = set(output).union(*inputs) - plates
            term = contract_to_tensor(tensor_tree, sum_dims,
                                      target_ordinal=plates.intersection(output),
                                      target_dims=sum_dims.intersection(output),
                                      ring=ring)
            # The contracted term may hold its dims in a different order than
            # requested; permute so its layout matches ``output`` exactly.
            if term._pyro_dims != output:
                term = term.permute(*map(term._pyro_dims.index, output))
                term._pyro_dims = output
            results.append(term)
    return tuple(results)