def _type_of_matrix()

in src/beanmachine/ppl/compiler/bmg_types.py [0:0]


def _type_of_matrix(v: torch.Tensor) -> BMGLatticeType:
    """Compute the smallest BMG lattice type that describes the constant ``v``.

    * An empty tensor is typed as ``Tensor``.
    * A one-element tensor of any rank is typed by calling ``type_of_value``
      on its ``float`` value.
    * A tensor of rank greater than two is typed as ``Tensor``.
    * Otherwise the result is a matrix type. The tensor is row-major, but BMG
      expresses constant matrices in column-major form, so the tensor's
      column count is reported as the BMG row count and vice versa.

    The element type is the supremum of the types of all elements. When the
    supremum is Zero, One, Boolean or Probability, the result may be refined.
    If every tensor row (i.e. every BMG column) sums to one within
    ``simplex_precision``, the result becomes a simplex matrix (Probability
    elements) or a one-hot matrix (otherwise).
    """
    count = v.numel()

    # tensor([]) is useful neither as a value nor as a matrix.
    if count == 0:
        return Tensor

    # A lone element, whatever the rank, is treated as a single value.
    if count == 1:
        return type_of_value(float(v))  # pyre-fixme

    # CONSIDER: a 1 x 1 x 2 tensor such as [[[10, 20]]] could be reduced to a
    # 1 x 2 matrix by discarding dimensions of size one; for now any rank
    # above two is left as a plain tensor.
    if len(v.shape) > 2:
        return Tensor

    tensor_rows, tensor_cols = _size_to_rc(v.shape)

    # The analysis below still walks the tensor row by row, since that is the
    # convenient order for a row-major tensor; only the dimensions reported
    # to BMG are swapped.
    grid = v.view(tensor_rows, tensor_cols)
    bmg_rows, bmg_cols = tensor_cols, tensor_rows

    element_types = [
        type_of_value(entry) for tensor_row in grid for entry in tensor_row
    ]
    sup = supremum(*element_types)

    # The supremum of scalar element types must be a 1 x 1 matrix type; it
    # should never be top or bottom here.
    assert isinstance(sup, BMGMatrixType)
    assert sup.rows == 1
    assert sup.columns == 1

    if sup in {Real, PositiveReal, NegativeReal, Natural}:
        return sup.with_dimensions(bmg_rows, bmg_cols)

    # What remains: Zero (all zeros), One (all ones), Boolean (a mix of zeros
    # and ones), or Probability (at least one element strictly between 0 and
    # 1). The first two may be one-hot; Probability may be a simplex.
    assert sup in {Boolean, Zero, One, Probability}

    def row_sums_to_one(tensor_row: torch.Tensor) -> bool:
        return abs(float(tensor_row.sum()) - 1.0) <= simplex_precision

    # Each tensor row becomes a BMG column; every one of them must sum to one.
    if all(row_sums_to_one(tensor_row) for tensor_row in grid):
        if sup == Probability:
            return SimplexMatrix(bmg_rows, bmg_cols)
        return OneHotMatrix(bmg_rows, bmg_cols)

    # Probabilities that do not sum to one keep the Probability element type.
    # An all-zero matrix could serve as either Booleans or negative reals;
    # the Zero matrix type sits below both, so keep it rather than guess.
    if sup == Probability or sup == Zero:
        return sup.with_dimensions(bmg_rows, bmg_cols)

    # All ones, or a mixture of zeros and ones: Boolean is the smallest
    # remaining matrix type.
    return BooleanMatrix(bmg_rows, bmg_cols)