in pyro/ops/contract.py [0:0]
def naive_ubersum(equation, *operands, **kwargs):
    """
    Naive reference implementation of :func:`ubersum` via unrolling.

    Plate dims are eliminated by replicating each operand once per plate
    index (via ``_select``) and renaming its non-plate dims per plate context
    (via ``_DimUnroller``). The resulting plate-free equation is then
    evaluated with ``opt_einsum.contract``.

    This implementation should never raise ``NotImplementedError``.
    This implementation should agree with :func:`ubersum` whenever
    :func:`ubersum` does not raise ``NotImplementedError``.

    :param str equation: An einsum equation ``"inputs->outputs"``, where
        ``outputs`` may list several comma-separated outputs.
    :param operands: One tensor per comma-separated input of ``equation``.
    :param str backend: Optional keyword argument naming the ``opt_einsum``
        backend; defaults to ``'pyro.ops.einsum.torch_log'``.
    :param str plates: Optional keyword argument whose characters are the
        plate dims; defaults to ``''`` (no plates).
    :return: A tuple containing one result per output.
    :rtype: tuple
    :raises ValueError: If ``plates`` is nonempty and some dim is assigned
        inconsistent sizes by the operands.
    """
    # Split the equation. Several outputs are handled by recursing once per
    # output, so everything below may assume exactly one output.
    lhs, rhs = equation.split('->')
    output_list = rhs.split(',')
    if len(output_list) > 1:
        return tuple(naive_ubersum(lhs + '->' + single, *operands, **kwargs)[0]
                     for single in output_list)
    output, = output_list
    inputs = lhs.split(',')
    backend = kwargs.pop('backend', 'pyro.ops.einsum.torch_log')

    # With no plate dims there is nothing to unroll: defer directly.
    plates = set(kwargs.pop('plates', ''))
    if not plates:
        return (opt_einsum.contract(equation, *operands, backend=backend),)
    output_dims = set(output)

    # Record each dim's size, rejecting dims whose sizes disagree.
    sizes = {}
    for subscripts, operand in zip(inputs, operands):
        for dim, size in zip(subscripts, operand.shape):
            if dim not in sizes:
                sizes[dim] = size
            elif sizes[dim] != size:
                raise ValueError(u"Dimension size mismatch at dim '{}': {} vs {}"
                                 .format(dim, size, sizes[dim]))

    # By convention, a non-plate dim's plate context is the intersection of
    # the plate sets of every input in which it appears. A new set is built
    # on every update (no in-place &=) so no two dims ever share a set object.
    dim_to_ordinal = {}
    for input_dims in (set(s) for s in inputs):
        ordinal = input_dims & plates
        for dim in input_dims - plates:
            previous = dim_to_ordinal.get(dim, ordinal)
            dim_to_ordinal[dim] = previous & ordinal
    for dim in output_dims - plates:
        _check_plates_are_sensible({dim}, dim_to_ordinal[dim] - output_dims)

    unroll_dim = _DimUnroller(dim_to_ordinal)

    def flatten_subscripts(subscripts, plate_dims, index):
        # Map each non-plate dim of `subscripts` to its unrolled symbol for
        # this plate index; plate dims themselves are dropped.
        return ''.join(unroll_dim(d, dict(zip(plate_dims, index)))
                       for d in subscripts if d not in plates)

    # Replicate every operand once per combination of its plate indices.
    flat_inputs = []
    flat_operands = []
    for subscripts, operand in zip(inputs, operands):
        plate_dims = [d for d in subscripts if d in plates]
        # Plate positions counted from the right (negative offsets).
        offsets = [subscripts.index(d) - len(subscripts) for d in plate_dims]
        for index in itertools.product(*[range(sizes[d]) for d in plate_dims]):
            flat_inputs.append(flatten_subscripts(subscripts, plate_dims, index))
            flat_operands.append(_select(operand, offsets, index))

    # Evaluate the plate-free equation once per output plate index and write
    # each piece into the matching slice of a preallocated result.
    result = torch.empty(torch.Size(sizes[d] for d in output),
                         dtype=operands[0].dtype, device=operands[0].device)
    plate_dims = [d for d in output if d in plates]
    offsets = [output.index(d) - len(output) for d in plate_dims]
    flat_lhs = ','.join(flat_inputs)
    for index in itertools.product(*[range(sizes[d]) for d in plate_dims]):
        flat_equation = flat_lhs + '->' + flatten_subscripts(output, plate_dims, index)
        flat_result = opt_einsum.contract(flat_equation, *flat_operands, backend=backend)
        if not plate_dims:
            # The output has no plate dims, so one contraction is the answer.
            return (flat_result,)
        _select(result, offsets, index).copy_(flat_result)
    return (result,)