in python/singa/tensor.py [0:0]
def tensordot(A, B, axes=2):
    """Return the tensor product of ``A`` and ``B`` contracted along ``axes``.

    The contracted axes are moved to the end of ``A`` and to the front of
    ``B``, both operands are flattened to 2-D matrices, and one matrix
    multiplication (``mult``) is performed. The result is then reshaped to
    the remaining (non-contracted) dimensions of ``A`` followed by those of
    ``B``. The approach follows ``numpy.tensordot``:
    https://github.com/numpy/numpy/blob/v1.14.0/numpy/core/numeric.py#L1123-L1306

    Args:
        A: left operand; must expose ``shape`` and ``ndim()``.
        B: right operand; must expose ``shape`` and ``ndim()``.
        axes:
            - If an integer ``N``, the last ``N`` axes of ``A`` are paired,
              in order, with the first ``N`` axes of ``B``. For example,
              with ``axes=1`` and shapes (3, 2, 4) and (4, 2, 5) the result
              has shape (3, 2, 2, 5); with ``axes=2`` and shapes (3, 2, 4)
              and (2, 4, 5) the result has shape (3, 5).
            - Otherwise a pair ``(axes_A, axes_B)``. Each element is either
              a single integer or a list/tuple of integers; corresponding
              entries are paired for the sum-product. Negative entries
              count from the last axis. For example, with shapes (3, 2, 4)
              and (4, 2, 5) and ``axes=([1, 2], [1, 0])`` the result has
              shape (3, 5).

    Returns:
        The product reshaped to the non-contracted dimensions of ``A``
        followed by those of ``B``. When every axis of both operands is
        contracted, the result is reshaped to (1, 1).

    Raises:
        ValueError: if ``axes_A`` and ``axes_B`` differ in length, or if a
            pair of contracted axes have different sizes.
    """
    if isinstance(axes, int):
        axes_A = list(range(-axes, 0))
        axes_B = list(range(0, axes))
    else:
        axes_A, axes_B = axes

    # Accept tuples as well as lists; a bare integer means a single axis.
    # (Previously only ``list`` was recognized, so a tuple of axes was
    # wrapped as one bogus axis and failed when used as an index.)
    if isinstance(axes_A, (list, tuple)):
        axes_A = list(axes_A)
    else:
        axes_A = [axes_A]
    if isinstance(axes_B, (list, tuple)):
        axes_B = list(axes_B)
    else:
        axes_B = [axes_B]

    a_shape = A.shape
    nda = A.ndim()
    b_shape = B.shape
    ndb = B.ndim()

    if len(axes_A) != len(axes_B):
        raise ValueError("shape-mismatch for sum")
    # Compare sizes using the axes exactly as given, *before* normalizing
    # negative values, so an out-of-range axis still fails on indexing
    # instead of being silently shifted into range.
    for ax_a, ax_b in zip(axes_A, axes_B):
        if a_shape[ax_a] != b_shape[ax_b]:
            raise ValueError("shape-mismatch for sum")

    # Convert negative axes to their non-negative equivalents so they can be
    # compared against range(ndim) below.
    axes_A = [ax + nda if ax < 0 else ax for ax in axes_A]
    axes_B = [ax + ndb if ax < 0 else ax for ax in axes_B]

    # Axes that are kept (not summed over), in their original order.
    free_a = [k for k in range(nda) if k not in axes_A]
    free_b = [k for k in range(ndb) if k not in axes_B]
    olda = [a_shape[k] for k in free_a]
    oldb = [b_shape[k] for k in free_b]

    # Flattened sizes of the 2-D operands. For A of shape (3, 2, 4) with
    # axes_A = [1, 2] this gives a (3, 8) matrix: 3 free rows and 2 * 4
    # contracted columns.
    rows_a = 1
    for dim in olda:
        rows_a *= dim
    inner_a = 1
    for ax in axes_A:
        inner_a *= a_shape[ax]
    inner_b = 1
    for ax in axes_B:
        inner_b *= b_shape[ax]
    cols_b = 1
    for dim in oldb:
        cols_b *= dim

    # A: free axes first, contracted axes last; B: contracted axes first.
    at = reshape(transpose(A, free_a + axes_A), (rows_a, inner_a))
    bt = reshape(transpose(B, axes_B + free_b), (inner_b, cols_b))
    res = mult(at, bt)

    out_shape = olda + oldb
    if not out_shape:
        # Everything was contracted: keep the result as a (1, 1) tensor.
        out_shape = [1, 1]
    return res.reshape(tuple(out_shape))