def einsum()

in python/singa/tensor.py [0:0]


def _parse_einsum_subscripts(ops):
    '''Split a two-operand einsum string into its validated label terms.

    Args:
        ops(string): subscripts such as 'ij,jk->ik'; spaces are ignored.

    Returns:
        a tuple (term_a, term_b, out_term) of label strings.

    Raises:
        ValueError: if there is not exactly one '->', if the number of
            comma separated input terms is not two, if a label is repeated
            inside one term, or if an output label is in neither input term.
    '''
    # Spaces carry no meaning in the subscripts, so 'ij, jk -> ik' is accepted.
    spec = ops.replace(' ', '')

    sides = spec.split('->')
    if len(sides) != 2:
        raise ValueError(
            "einsum subscripts must contain exactly one '->', got '{}'".format(
                ops))

    in_terms = sides[0].split(',')
    out_term = sides[1]
    if len(in_terms) != 2:
        raise ValueError(
            "einsum subscripts must describe exactly two inputs, got {} in "
            "'{}'".format(len(in_terms), ops))
    term_a, term_b = in_terms

    # The axis bookkeeping in einsum() maps every label to exactly one axis,
    # so a label used twice in one term (e.g. the diagonal 'ii') cannot be
    # expressed and has to be rejected here.
    for term in (term_a, term_b, out_term):
        if len(set(term)) != len(term):
            raise ValueError(
                "repeated subscript in term '{}' of '{}' is not supported".
                format(term, ops))

    unknown = sorted(set(out_term) - set(term_a) - set(term_b))
    if unknown:
        raise ValueError(
            "output subscripts {} do not appear in the inputs of '{}'".format(
                unknown, ops))

    return term_a, term_b, out_term


def einsum(ops, *args):
    '''Evaluate the Einstein summation convention on two tensors.

    Both operands are expanded to one common set of axes (one axis per
    distinct input label, in sorted label order), multiplied element-wise,
    summed over every label that is missing from the output, and finally
    transposed into the label order requested by the output term.

    Warning: only two operands are supported.

    Args:
        ops(string): the subscripts for the summation, such as
            'ki,kj->kij'. Each label is a single character (e.g. one of the
            26 lowercase letters) and may appear at most once per term.
            Spaces are ignored.
        args(list of array_like): the tensors for the operation; exactly
            two tensors must be given.

    Returns:
        Singa.Tensor, the output tensor of the einsum calculation.

    Raises:
        ValueError: if ops is empty, if the number of operands is not two,
            if ops is malformed (see _parse_einsum_subscripts), if the
            number of labels of a term differs from the number of
            dimensions of its tensor, or if the sizes of a label shared by
            both tensors disagree.

    Note:
        TODO carried over from the original implementation: the helpers
        sum(A, axis=None), repeat(A, repeats) and transpose(A, axes=None)
        are meant to be finished in cpp (just like the numpy functions).

    The best way to understand this function is to try the examples below.
    In every example A and B are tensors filled with the values 0..11.

    With A of shape (4, 3) and B of shape (3, 4) the calculation is the
    same as the normal 'mult':
    Res = einsum('ij,jk->ik', A, B)

    >>> [[ 20  23  26  29]
         [ 56  68  80  92]
         [ 92 113 134 155]
         [128 158 188 218]]

    With A and B both of shape (4, 3) the calculation is the same as the
    normal 'eltwise_mult':
    Res = einsum('ki,ki->ki', A, B)

    >>> [[  0   1   4]
         [  9  16  25]
         [ 36  49  64]
         [ 81 100 121]]

    With A of shape (4, 3), a row-wise outer product:
    Res = einsum('ki,kj->kij', A, A)

    >>> [[[  0   0   0]
          [  0   1   2]
          [  0   2   4]]
         [[  9  12  15]
          [ 12  16  20]
          [ 15  20  25]]
         [[ 36  42  48]
          [ 42  49  56]
          [ 48  56  64]]
         [[ 81  90  99]
          [ 90 100 110]
          [ 99 110 121]]]

    With A of shape (3, 2, 2):
    Res = einsum('kia,kja->kij', A, A)

    >>> [[[  1   3]
          [  3  13]]
         [[ 41  59]
          [ 59  85]]
         [[145 179]
          [179 221]]]
    '''
    if len(ops) == 0:
        raise ValueError("No input operands")

    if len(args) != 2:
        raise ValueError("Currently only two operands are supported")

    # Validate the subscripts before touching the tensors so that malformed
    # input fails with a clear ValueError.
    term_a, term_b, out_term = _parse_einsum_subscripts(ops)
    A, B = args

    if A.ndim() != len(term_a) or B.ndim() != len(term_b):
        raise ValueError("input dim doesn't match operands")

    labels_a = set(term_a)
    labels_b = set(term_b)

    # Common axis order of the intermediate product: all labels, sorted.
    all_labels = sorted(labels_a | labels_b)

    # Axes of the intermediate product that are summed away, i.e. labels
    # that appear in the inputs but not in the output (ascending order).
    sum_axes = [
        axis for axis, label in enumerate(all_labels) if label not in out_term
    ]

    # Labels each operand lacks and has to be broadcast along, together
    # with the sizes of those axes taken from the other operand.
    missing_in_a = sorted(labels_b - labels_a)
    missing_in_b = sorted(labels_a - labels_b)
    extra_shape_a = [B.shape[term_b.index(x)] for x in missing_in_a]
    extra_shape_b = [A.shape[term_a.index(x)] for x in missing_in_b]

    # After the expansion below an operand has its own axes followed by the
    # broadcast axes; these permutations bring both into all_labels order.
    perm_a = [(list(term_a) + missing_in_a).index(x) for x in all_labels]
    perm_b = [(list(term_b) + missing_in_b).index(x) for x in all_labels]

    # Broadcast by repeating the elements once per combination of the
    # missing axes and reshaping, which appends the new axes at the end
    # (relies on repeat() without an axis repeating element by element).
    # product() is fed [1] when nothing has to be broadcast, as before.
    mult_A = repeat(A, product(extra_shape_a or [1]))
    mult_A = mult_A.reshape(list(A.shape) + extra_shape_a)
    mult_A = transpose(mult_A, perm_a)
    mult_B = repeat(B, product(extra_shape_b or [1]))
    mult_B = mult_B.reshape(list(B.shape) + extra_shape_b)
    mult_B = transpose(mult_B, perm_b)

    # A label shared by both operands must have the same size in both.
    if mult_A.shape != mult_B.shape:
        raise ValueError("Error: matrix dimension mismatch")
    res = eltwise_mult(mult_A, mult_B)

    # Reduce from the highest axis down so that the indices of the axes
    # still to be reduced stay valid.
    for axis in reversed(sum_axes):
        res = sum(res, axis=axis)

    # The remaining axes are in sorted label order; permute them into the
    # order requested by the output term.
    sorted_out = sorted(out_term)
    transpose_res = [sorted_out.index(x) for x in out_term]
    res = transpose(res, transpose_res)

    return res