def tensordot()

in python/singa/tensor.py [0:0]


def tensordot(A, B, axes=2):
    """Returns the tensor multiplication of two tensors along specified axes.

    This is equivalent to computing the dot product along the specified axes,
    which are treated as one axis by reshaping: both operands are transposed
    so that the contracted axes are adjacent, reshaped into 2-D matrices,
    multiplied, and the product is reshaped to the remaining (free) axes.

    Args:
        A: Singa.Tensor
        B: Singa.Tensor
        axes:
            - If it is an integer ``n``, the last ``n`` axes of ``A`` and the
              first ``n`` axes of ``B`` are contracted.
            - If it is a pair ``(axes_A, axes_B)``, each element is either a
              single integer or a sequence of integers (list, tuple, ...)
              giving the axes of ``A`` and ``B``. The corresponding axes are
              paired for sum-product. Negative axes count from the end.

    Returns:
        singa.tensor: The tensor product of ``A`` and ``B`` along the axes
        specified by ``axes``. Its shape is the free axes of ``A`` followed
        by the free axes of ``B``; when no free axis is left the result is
        reshaped to ``(1, 1)``.

    Raises:
        ValueError: if the two axis lists differ in length, or if a pair of
            contracted axes have different sizes.

    Thanks to numpy.tensordot.
    the link is https://github.com/numpy/numpy/blob/v1.14.0/numpy/core/numeric.py#L1123-L1306
    """
    # When axes is an integer, the contracted axes are the last ones of A and
    # the first ones of B. For example, with axes=1 and A in shape (3,2,4),
    # B in shape (4,2,5), the result is in shape (3,2,2,5); with axes=2 and
    # A, B in shape (3,2,4) and (2,4,5), the result is in shape (3,5).
    if isinstance(axes, int):
        axes_A = list(range(-axes, 0))
        axes_B = list(range(0, axes))
    else:
        # axes is a pair. For example, with A in shape (3,2,4), B in shape
        # (4,2,5) and axes=([1,2],[1,0]), the result is in shape (3,5).
        axes_A, axes_B = axes

    # Accept any sequence of ints (list, tuple, range, ...) as well as a bare
    # int. Copying into a fresh list also keeps the caller's object untouched
    # when negative axes are normalised below.
    try:
        axes_A = list(axes_A)
    except TypeError:
        axes_A = [axes_A]
    try:
        axes_B = list(axes_B)
    except TypeError:
        axes_B = [axes_B]

    # a_shape and b_shape are the shapes of A and B; nda and ndb their ranks.
    a_shape = A.shape
    nda = A.ndim()
    b_shape = B.shape
    ndb = B.ndim()

    # The axis lists must pair up one-to-one and each pair must agree in size.
    if len(axes_A) != len(axes_B):
        raise ValueError("shape-mismatch for sum")
    for k, (ax_a, ax_b) in enumerate(zip(axes_A, axes_B)):
        if a_shape[ax_a] != b_shape[ax_b]:
            raise ValueError("shape-mismatch for sum")
        # Turn negative axes into their non-negative equivalents so that the
        # membership tests below work on plain positions.
        if ax_a < 0:
            axes_A[k] = ax_a + nda
        if ax_b < 0:
            axes_B[k] = ax_b + ndb

    # Move the contracted axes of A to the end and flatten A to
    # (product of free dims, product of contracted dims).
    free_a = [k for k in range(nda) if k not in axes_A]
    newaxes_a = free_a + axes_A
    contracted_a = 1
    for axis in axes_A:
        contracted_a *= a_shape[axis]
    kept_a = 1
    for axis in free_a:
        kept_a *= a_shape[axis]
    # For example, with A in shape (3,2,4), B in shape (4,2,5) and
    # axes=([1,2],[1,0]), newshape_a is (3,8) and newshape_b is (8,5).
    newshape_a = (kept_a, contracted_a)
    # olda holds the free dims of A that show up in the result shape.
    olda = [a_shape[axis] for axis in free_a]

    # Move the contracted axes of B to the front and flatten B to
    # (product of contracted dims, product of free dims).
    free_b = [k for k in range(ndb) if k not in axes_B]
    newaxes_b = axes_B + free_b
    contracted_b = 1
    for axis in axes_B:
        contracted_b *= b_shape[axis]
    kept_b = 1
    for axis in free_b:
        kept_b *= b_shape[axis]
    newshape_b = (contracted_b, kept_b)
    # oldb holds the free dims of B that show up in the result shape.
    oldb = [b_shape[axis] for axis in free_b]

    A = transpose(A, newaxes_a)
    B = transpose(B, newaxes_b)
    at = reshape(A, newshape_a)
    bt = reshape(B, newshape_b)

    res = mult(at, bt)

    # With every axis contracted there is no free dim left; keep the result
    # as a (1, 1) tensor in that case.
    out_shape = olda + oldb
    if not out_shape:
        out_shape = [1, 1]
    res = res.reshape(tuple(out_shape))

    return res