in pyro/ops/contract.py [0:0]
def einsum(equation, *operands, **kwargs):
"""
Generalized plated sum-product algorithm via tensor variable elimination.

This generalizes :func:`~pyro.ops.einsum.contract` in two ways:

1. Multiple outputs are allowed, and intermediate results can be shared.
2. Inputs and outputs can be plated along symbols given in ``plates``;
   reductions along ``plates`` are product reductions.

The best way to understand this function is to try the examples below,
which show how :func:`einsum` calls can be implemented as multiple calls
to :func:`~pyro.ops.einsum.contract` (which is generally more expensive).

To illustrate multiple outputs, note that the following are equivalent::

    z1, z2, z3 = einsum('ab,bc->a,b,c', x, y)  # multiple outputs

    z1 = contract('ab,bc->a', x, y)
    z2 = contract('ab,bc->b', x, y)
    z3 = contract('ab,bc->c', x, y)

To illustrate plated inputs, note that the following are equivalent::

    assert len(x) == 3 and len(y) == 3
    z = einsum('ab,ai,bi->b', w, x, y, plates='i')

    z = contract('ab,a,a,a,b,b,b->b', w, *x, *y)

When a sum dimension `a` always appears with a plate dimension `i`,
then `a` corresponds to a distinct symbol for each slice of the plate
`i`. Thus the following are equivalent::

    assert len(x) == 3 and len(y) == 3
    z = einsum('ai,ai->', x, y, plates='i')

    z = contract('a,b,c,a,b,c->', *x, *y)

When such a sum dimension appears in the output, it must be
accompanied by all of its plate dimensions, e.g. the following are
equivalent::

    assert len(x) == 3 and len(y) == 3
    z = einsum('abi,abi->bi', x, y, plates='i')

    z0 = contract('ab,ac,ad,ab,ac,ad->b', *x, *y)
    z1 = contract('ab,ac,ad,ab,ac,ad->c', *x, *y)
    z2 = contract('ab,ac,ad,ab,ac,ad->d', *x, *y)
    z = torch.stack([z0, z1, z2])

Note that each plate slice through the output is multilinear in all plate
slices through all inputs, thus e.g. batch matrix multiply would be
implemented *without* ``plates``, so the following are all equivalent::

    xy = einsum('abc,acd->abd', x, y, plates='')

    xy = torch.stack([xa.mm(ya) for xa, ya in zip(x, y)])

    xy = torch.bmm(x, y)

Among all valid equations, some computations are polynomial in the sizes
of the input tensors while others are exponential; this function raises
:py:class:`NotImplementedError` whenever the computation would be
exponential.

:param str equation: An einsum equation, optionally with multiple outputs.
:param torch.Tensor operands: A collection of tensors.
:param str plates: An optional string of plate symbols.
:param str backend: An optional einsum backend, defaults to 'torch'.
:param dict cache: An optional :func:`~opt_einsum.shared_intermediates`
cache.
:param bool modulo_total: Optionally allow einsum to arbitrarily scale
each result plate, which can significantly reduce computation. This is
safe to set whenever each result plate denotes a nonnormalized
probability distribution whose total is not of interest.
:return: a tuple of tensors of requested shape, one entry per output.
:rtype: tuple
:raises ValueError: if tensor sizes mismatch or an output requests a
plated dim without that dim's plates.
:raises NotImplementedError: if contraction would have cost exponential in
the size of any input tensor.
"""
# Extract kwargs.
cache = kwargs.pop('cache', None)
plates = kwargs.pop('plates', '')
backend = kwargs.pop('backend', 'torch')
modulo_total = kwargs.pop('modulo_total', False)
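# Look up the ring implementation for the requested backend; unknown
# backends are reported along with the list of supported ones.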
try:
Ring = BACKEND_TO_RING[backend]
except KeyError:
raise NotImplementedError('\n'.join(
['Only the following pyro backends are currently implemented:'] +
list(BACKEND_TO_RING)))
# Parse generalized einsum equation.
if '.' in equation:
raise NotImplementedError('einsum does not yet support ellipsis notation')
inputs, outputs = equation.split('->')
inputs = inputs.split(',')
outputs = outputs.split(',')
assert len(inputs) == len(operands)
assert all(isinstance(x, torch.Tensor) for x in operands)
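# Nonempty outputs are currently supported only modulo an arbitrary scale
# factor on each result plate, so require callers to opt in explicitly.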
if not modulo_total and any(outputs):
raise NotImplementedError('Try setting modulo_total=True and ensuring that your use case '
'allows an arbitrary scale factor on each result plate.')
if len(operands) != len(set(operands)):
operands = [x[...] for x in operands] # ensure tensors are unique
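# (x[...] returns a fresh view, so duplicate operands no longer share the
# per-tensor ._pyro_dims metadata attached below)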
# Check sizes.
with ignore_jit_warnings():
dim_to_size = {}
for dims, term in zip(inputs, operands):
for dim, size in zip(dims, map(int, term.shape)):
old = dim_to_size.setdefault(dim, size)
if old != size:
raise ValueError(u"Dimension size mismatch at dim '{}': {} vs {}"
.format(dim, size, old))
# Construct a tensor tree shared by all outputs.
tensor_tree = OrderedDict()
plates = frozenset(plates)
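# Group each input term by its ordinal: the set of plate symbols appearing
# among its dims. Terms with the same ordinal share a slot in the tree.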
for dims, term in zip(inputs, operands):
assert len(dims) == term.dim()
term._pyro_dims = dims
ordinal = plates.intersection(dims)
tensor_tree.setdefault(ordinal, []).append(term)
# Compute outputs, sharing intermediate computations.
results = []
with shared_intermediates(cache) as cache:
ring = Ring(cache, dim_to_size=dim_to_size)
for output in outputs:
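# Collect every non-plate symbol appearing in the inputs or in this
# output; these are the dims over which sum-reductions may occur.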
sum_dims = set(output).union(*inputs) - set(plates)
term = contract_to_tensor(tensor_tree, sum_dims,
target_ordinal=plates.intersection(output),
target_dims=sum_dims.intersection(output),
ring=ring)
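# The contracted result may come back with its dims in a different
# order; permute it to match the requested output order.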
if term._pyro_dims != output:
term = term.permute(*map(term._pyro_dims.index, output))
term._pyro_dims = output
results.append(term)
return tuple(results)
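

# A minimal usage sketch (illustrative only, not part of this module), based
# on the multiple-output example in the docstring above; note that nonempty
# outputs currently require modulo_total=True:
#
#     x = torch.randn(2, 3)
#     y = torch.randn(3, 4)
#     za, zb, zc = einsum('ab,bc->a,b,c', x, y, modulo_total=True)
#     assert za.shape == (2,) and zb.shape == (3,) and zc.shape == (4,)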