def naive_ubersum()

in pyro/ops/contract.py [0:0]


def naive_ubersum(equation, *operands, **kwargs):
    """
    Naive reference implementation of :func:`ubersum` via unrolling.

    This implementation should never raise ``NotImplementedError``.
    This implementation should agree with :func:`ubersum` whenever
    :func:`ubersum` does not raise ``NotImplementedError``.

    Every plate dim is unrolled into one slice per index. The non-plate dims
    of each slice are renamed via ``_DimUnroller`` according to their plate
    context (presumably giving each plate index its own symbol). The
    resulting plate-free equation is handed to ``opt_einsum.contract``.

    :param str equation: An einsum-style equation such as ``"ab,bc->a"``;
        the right-hand side may list several comma-separated outputs.
    :param operands: One tensor per comma-separated input of ``equation``.
    :param str backend: Keyword-only. The ``opt_einsum`` backend; defaults
        to ``'pyro.ops.einsum.torch_log'``.
    :param plates: Keyword-only. An iterable (typically a string) of
        single-character dims to treat as plates. Empty or ``None`` means no
        plates, in which case this is a plain ``opt_einsum.contract`` call.
    :return: A tuple with one tensor per output listed in ``equation``.
    :rtype: tuple
    :raises ValueError: When ``plates`` is non-empty: if a dim has
        inconsistent sizes across operands, or if the size of an output dim
        cannot be determined from the operands.

    Any other keyword arguments are accepted and ignored.
    """
    # Parse equation, without loss of generality assuming a single output.
    inputs, outputs = equation.split('->')
    outputs = outputs.split(',')
    if len(outputs) > 1:
        # kwargs still holds 'backend' and 'plates' here; they are popped below.
        return tuple(naive_ubersum(inputs + '->' + output, *operands, **kwargs)[0]
                     for output in outputs)
    output, = outputs
    inputs = inputs.split(',')
    backend = kwargs.pop('backend', 'pyro.ops.einsum.torch_log')

    # Split dims into plate dims, contraction dims, and dims to keep.
    # plates=None is treated the same as the default of no plates.
    plates = set(kwargs.pop('plates', None) or '')
    if not plates:
        result = opt_einsum.contract(equation, *operands, backend=backend)
        return (result,)
    output_dims = set(output)

    # Collect sizes of all dimensions.
    sizes = {}
    for input_, operand in zip(inputs, operands):
        for dim, size in zip(input_, operand.shape):
            old = sizes.setdefault(dim, size)
            if old != size:
                raise ValueError(u"Dimension size mismatch at dim '{}': {} vs {}"
                                 .format(dim, size, old))
    # Fail with a clear message here rather than a bare KeyError further down.
    for dim in output:
        if dim not in sizes:
            raise ValueError(u"Unknown size for output dim '{}'".format(dim))

    # Compute plate context for each non-plate dim, by convention the
    # intersection over all plate contexts of tensors in which the dim appears.
    dim_to_ordinal = {}
    for dims in map(set, inputs):
        ordinal = dims & plates
        for dim in dims - plates:
            dim_to_ordinal[dim] = dim_to_ordinal.get(dim, ordinal) & ordinal
    for dim in output_dims - plates:
        _check_plates_are_sensible({dim}, dim_to_ordinal[dim] - output_dims)

    # Unroll by replicating along plate dimensions: each operand is sliced at
    # every index of its plate dims, and each slice becomes its own input term.
    unroll_dim = _DimUnroller(dim_to_ordinal)
    flat_inputs = []
    flat_operands = []
    for input_, operand in zip(inputs, operands):
        local_dims = [d for d in input_ if d in plates]
        # Negative (right-aligned) positions of this operand's plate dims.
        offsets = [input_.index(d) - len(input_) for d in local_dims]
        for index in itertools.product(*(range(sizes[d]) for d in local_dims)):
            flat_inputs.append(''.join(unroll_dim(d, dict(zip(local_dims, index)))
                                       for d in input_ if d not in plates))
            flat_operands.append(_select(operand, offsets, index))

    # Defer to unplated einsum, once per index of the output's plate dims.
    flat_lhs = ','.join(flat_inputs)  # identical for every output slice
    output_plates = [d for d in output if d in plates]

    def contract_slice(index):
        # Contract all flat inputs down to the output slice at plate ``index``.
        flat_output = ''.join(unroll_dim(d, dict(zip(output_plates, index)))
                              for d in output if d not in plates)
        return opt_einsum.contract(flat_lhs + '->' + flat_output,
                                   *flat_operands, backend=backend)

    if not output_plates:
        # No plated output dims: a single contraction yields the whole result.
        return (contract_slice(()),)

    result = torch.empty(torch.Size(sizes[d] for d in output),
                         dtype=operands[0].dtype, device=operands[0].device)
    output_offsets = [output.index(d) - len(output) for d in output_plates]
    for index in itertools.product(*(range(sizes[d]) for d in output_plates)):
        _select(result, output_offsets, index).copy_(contract_slice(index))
    return (result,)