in python/singa/tensor.py [0:0]
def _parse_einsum_subscripts(subscripts):
    '''Split an explicit two-operand subscripts string into index groups.

    Args:
        subscripts(string): subscripts without whitespace, e.g. 'ij,jk->ik'.

    Returns:
        a tuple (sub_a, sub_b, sub_out) holding the index string of the
        first operand, of the second operand and of the output.

    Raises:
        ValueError: if the string does not contain exactly one '->', does
            not name exactly two input operands, repeats an index inside
            one group, or uses an output index that no input provides.
    '''
    halves = subscripts.split('->')
    if len(halves) != 2:
        raise ValueError(
            "subscripts must contain exactly one '->': {!r}".format(subscripts))
    inputs, sub_out = halves

    operands = inputs.split(',')
    if len(operands) != 2:
        raise ValueError(
            "subscripts must name exactly two input operands: {!r}".format(
                subscripts))
    sub_a, sub_b = operands

    # The expansion in einsum() maps every index to exactly one axis, so a
    # label used twice in the same group (a diagonal / trace) cannot be
    # expressed and would only yield axis lists of the wrong length.
    for group in (sub_a, sub_b, sub_out):
        if len(set(group)) != len(group):
            raise ValueError(
                "repeated index in {!r} is not supported".format(group))

    unknown = set(sub_out) - set(sub_a) - set(sub_b)
    if unknown:
        raise ValueError(
            "output indices {} do not appear in the inputs".format(
                sorted(unknown)))
    return sub_a, sub_b, sub_out


def einsum(ops, *args):
    '''Evaluate an Einstein summation over exactly two tensors.

    Both operands are expanded to the full index space (every index that
    appears in either operand), multiplied element-wise, summed over the
    indices missing from the output and finally transposed into the
    requested output order. Because both operands are materialised at the
    full index-space size, memory use grows with the product of all index
    extents.

    TODO: move the helpers used here (sum with axis, repeat, transpose)
    into cpp.

    Args:
        ops(string): subscripts in explicit form such as 'ki,kj->kij'.
            Each character is one index label; whitespace is ignored. The
            '->' and both comma separated operand groups are required, and
            an index may appear at most once per group.
        args(list of array_like): the two tensors to combine. Each must
            provide ndim() and shape.

    Returns:
        Singa.Tensor, the result of the einsum calculation.

    Raises:
        ValueError: if ops is empty or malformed, if the number of tensors
            is not two, if a tensor's number of dimensions differs from the
            number of indices given for it, or if an index shared by both
            tensors has different sizes in them.

    Examples (values shown as nested lists, A_ = 0, 1, ..., 11):

        A = A_ reshaped to (4, 3), B = A_ reshaped to (3, 4)
        einsum('ij,jk->ik', A, B)       # same as the matrix product
        >>> [[ 20  23  26  29]
             [ 56  68  80  92]
             [ 92 113 134 155]
             [128 158 188 218]]

        A = B = A_ reshaped to (4, 3)
        einsum('ki,ki->ki', A, B)       # same as element-wise product
        >>> [[  0   1   4]
             [  9  16  25]
             [ 36  49  64]
             [ 81 100 121]]

        A = A_ reshaped to (4, 3)
        einsum('ki,kj->kij', A, A)      # row-wise outer product
        >>> [[[  0   0   0]
              [  0   1   2]
              [  0   2   4]]
             [[  9  12  15]
              [ 12  16  20]
              [ 15  20  25]]
             [[ 36  42  48]
              [ 42  49  56]
              [ 48  56  64]]
             [[ 81  90  99]
              [ 90 100 110]
              [ 99 110 121]]]

        A = A_ reshaped to (3, 2, 2)
        einsum('kia,kja->kij', A, A)    # batched A times A transposed
        >>> [[[  1   3]
              [  3  13]]
             [[ 41  59]
              [ 59  85]]
             [[145 179]
              [179 221]]]
    '''
    # Whitespace carries no meaning, so 'ij, jk -> ik' is accepted too.
    subscripts = ''.join(ops.split())
    if not subscripts:
        raise ValueError("No input operands")
    if len(args) != 2:
        raise ValueError("Currently only two operands are supported")

    sub_a, sub_b, sub_out = _parse_einsum_subscripts(subscripts)
    A, B = args

    if A.ndim() != len(sub_a) or B.ndim() != len(sub_b):
        raise ValueError("input dim doesn't match operands")

    # An index shared by both operands must have the same size in each.
    # Checking it here fails before the (potentially large) expansion.
    for pos_a, idx in enumerate(sub_a):
        pos_b = sub_b.find(idx)
        if pos_b != -1 and A.shape[pos_a] != B.shape[pos_b]:
            raise ValueError("Error: matrix dimension mismatch")

    # Axis order of the expanded operands: all indices, alphabetically.
    all_idx = sorted(set(sub_a) | set(sub_b))

    # Indices an operand lacks; it is replicated along them using the
    # sizes taken from the other operand.
    extra_a = sorted(set(sub_b) - set(sub_a))
    extra_b = sorted(set(sub_a) - set(sub_b))
    extent_a = [B.shape[sub_b.index(x)] for x in extra_a]
    extent_b = [A.shape[sub_a.index(x)] for x in extra_b]

    # After repeat + reshape the replicated axes trail the operand's own
    # axes; these permutations bring both operands into all_idx order.
    layout_a = list(sub_a) + extra_a
    layout_b = list(sub_b) + extra_b
    transpose_A = [layout_a.index(x) for x in all_idx]
    transpose_B = [layout_b.index(x) for x in all_idx]

    # product() is given [1] when nothing has to be replicated, as an
    # empty list is never passed to it.
    mult_A = repeat(A, product(extent_a or [1]))
    mult_A = mult_A.reshape(list(A.shape) + extent_a)
    mult_A = transpose(mult_A, transpose_A)
    mult_B = repeat(B, product(extent_b or [1]))
    mult_B = mult_B.reshape(list(B.shape) + extent_b)
    mult_B = transpose(mult_B, transpose_B)

    res = eltwise_mult(mult_A, mult_B)

    # Sum out every index absent from the output, highest axis first so
    # the numbers of the axes still to be reduced stay valid.
    # NOTE: `sum` is called with an `axis` keyword, so it has to resolve
    # to a tensor-level sum defined elsewhere, not the builtin.
    kept = set(sub_out)
    sum_axes = [axis for axis, idx in enumerate(all_idx) if idx not in kept]
    for axis in reversed(sum_axes):
        res = sum(res, axis=axis)

    # The remaining axes are in alphabetical order; permute them into the
    # order requested on the right-hand side of '->'.
    sorted_out = sorted(sub_out)
    res = transpose(res, [sorted_out.index(x) for x in sub_out])
    return res