def einsum()

in pyro/ops/einsum/torch_log.py
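
The contraction runs in log space: each operand is shifted by its max over
the dims being summed out, exponentiated, contracted by a plain einsum, and
the shifts are added back after taking the log. For one operand this is the
standard shifted log-sum-exp identity (indices here are illustrative):

    log sum_j exp(x[i, j]) = m[i] + log sum_j exp(x[i, j] - m[i]),
    where m[i] = max_j x[i, j]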


import torch

from pyro.ops.special import safe_log


def einsum(equation, *operands):
    """
    Log-sum-exp implementation of einsum.
    """
    # Rename symbols to support PyTorch 0.4.1 and earlier,
    # which allow only the symbols a-z in einsum equations.
    symbols = sorted(set(equation) - set(',->'))
    rename = dict(zip(symbols, 'abcdefghijklmnopqrstuvwxyz'))
    equation = ''.join(rename.get(s, s) for s in equation)

    inputs, output = equation.split('->')
    if inputs == output:
        return operands[0][...]  # create a new object
    inputs = inputs.split(',')

    shifts = []
    exp_operands = []
    for dims, operand in zip(inputs, operands):
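        # Compute a stabilizing shift: the per-slice max over every dim
        # that is summed out of this operand.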
        shift = operand
        for i, dim in enumerate(dims):
            if dim not in output:
                shift = shift.max(i, keepdim=True)[0]
        # Clamp to avoid NaN from (-inf) - (-inf) when a slice is all -inf.
        shift = shift.clamp(min=torch.finfo(shift.dtype).min)
        exp_operands.append((operand - shift).exp())

        # Reshape and permute the shift so it broadcasts against the output.
        shift = shift.reshape(torch.Size(size for size, dim in zip(operand.shape, dims)
                                         if dim in output))
        if shift.dim():
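            # Left-pad with singleton dims for any output dims this operand
            # lacks, then permute the kept dims into output order.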
            shift = shift.reshape((1,) * (len(output) - shift.dim()) + shift.shape)
            dims = [dim for dim in dims if dim in output]
            dims = [dim for dim in output if dim not in dims] + dims
            shift = shift.permute(*(dims.index(dim) for dim in output))
        shifts.append(shift)

    result = safe_log(torch.einsum(equation, exp_operands))
    return sum(shifts + [result])
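
A minimal usage sketch (the import path follows the file location above;
the shapes and tolerance are illustrative):

import torch

from pyro.ops.einsum.torch_log import einsum

# Two log-space factors, f(a, b) and g(b, c).
log_f = torch.randn(3, 4)
log_g = torch.randn(4, 5)

# Contract over b in log space: log sum_b exp(log_f[a, b] + log_g[b, c]).
log_h = einsum('ab,bc->ac', log_f, log_g)

# Naive reference: exponentiate, matrix-multiply, take the log.
expected = (log_f.exp() @ log_g.exp()).log()
assert torch.allclose(log_h, expected, atol=1e-5)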