in python/tvm/topi/arm_cpu/tensor_intrin.py
def gemm_4x4_int8_int8_int32(M, N, K, unroll, in_type):
"""
Int8 4x4 matrix multiplication and accumulation using a sequence of
umull -> uadalp -> umull2 -> uadalp instructions. This function
takes two arrays A[4][K] and B[4][K] of int8 (or uint8) data type, and
produces a 4x4 int32 matrix which is equal to A*B'.
The pseudo code is as follows.
.. code-block:: c
void gemm_4x4_int8_int8_int32(int8 A[4][K], int8 B[4][K], int32 C[4][4]){
for (int i = 0; i < 4; i++){
for (int j = 0; j < 4; j++){
for (int k = 0; k < K; k++){
C[i][j] += A[i][k] * B[j][k];
}
}
}
}
Notes:
* The tiling strategy is picked to maximize register usage.
Parameters
----------
M : int
Rows of the matrix A (rows of the output tile, at most 4)
N : int
Rows of the matrix B, i.e. columns of B' (columns of the output tile, at most 4)
K : int
Columns of the matrices A and B (the reduction dimension, a multiple of 16)
unroll : bool
Unroll the loop accumulation if True
in_type : str, {'uint8', 'int8'}
Data type of the input matrices A and B
Returns
-------
intrin : TensorIntrin
The ARM uint8/int8 TensorIntrin that can be used when tensorizing a schedule
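Examples
--------
A minimal usage sketch. The schedule ``s``, the output tensor ``C`` and the
axis ``inner_row_axis`` below are illustrative assumptions, not values produced
by this function; the intrinsic is meant to be applied to a 4x4 output tile
whose reduction axis has been split by 16.
.. code-block:: python
gemm = gemm_4x4_int8_int8_int32(M, N, K, unroll=True, in_type="int8")
s[C].tensorize(inner_row_axis, gemm)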
"""
assert in_type in ["uint8", "int8"]
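# A and B are expected in a tiled layout (K//16, rows, 16): the reduction
# dimension is split by 16 so that each innermost slice is one 128-bit vector
# of int8/uint8 values, which is what a single vload in the intrinsic fetches.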
A = te.placeholder((K // 16, te.var("m"), 16), dtype=in_type, name="A")
B = te.placeholder((K // 16, te.var("n"), 16), dtype=in_type, name="B")
dtype_vec = in_type + "x16"
idxm = tvm.tir.indexmod
k = te.reduce_axis((0, K), "k")
C = te.compute(
(te.var("m"), te.var("n")),
lambda x, y: te.sum(
A[k // 16, x, idxm(k, 16)].astype("int32") * B[k // 16, y, idxm(k, 16)].astype("int32"),
axis=k,
),
name="C",
)
a_buffer = tvm.tir.decl_buffer(
A.shape,
dtype=in_type,
name="a_buffer",
offset_factor=1,
strides=[te.var("sa_1"), te.var("sa_2"), 1],
)
b_buffer = tvm.tir.decl_buffer(
B.shape,
dtype=in_type,
name="b_buffer",
offset_factor=1,
strides=[te.var("sb_1"), te.var("sb_2"), 1],
)
c_buffer = tvm.tir.decl_buffer(
C.shape, dtype="int32", name="c_buffer", offset_factor=1, strides=[te.var("sc"), 1]
)
# Intrinsics used in the following algorithm
umull_intrin = "llvm.aarch64.neon.umull" if in_type == "uint8" else "llvm.aarch64.neon.smull"
uaddlp_intrin = "llvm.aarch64.neon.uaddlp" if in_type == "uint8" else "llvm.aarch64.neon.saddlp"
addp_intrin = "llvm.aarch64.neon.addp"
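# These LLVM intrinsics correspond to the AArch64 NEON instructions
# umull/smull (widening multiply), uadalp/sadalp (pairwise add and
# accumulate long) and addp (pairwise add).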
def uadalp(a, b):
"""Add pair and accumulate
Parameters:
----------
a: int16x8 vector
b: int16x8 vector
Returns:
--------
return an int32x4 vector
Pseudocode:
----------
a += (b0+b1, b2+b3, b4+b5, b6+b7)
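Example (illustrative values):
a = (10, 20, 30, 40), b = (1, 2, 3, 4, 5, 6, 7, 8)
uadalp(a, b) = (10+1+2, 20+3+4, 30+5+6, 40+7+8) = (13, 27, 41, 55)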
"""
return a + tvm.tir.call_llvm_pure_intrin(
"int32x4", uaddlp_intrin, tvm.tir.const(1, "uint32"), b
)
def umull(a, b):
"""Multiply long (higher part)
Parameters:
----------
a: int8x16 vector
b: int8x16 vector
Returns:
--------
return an int16x8 vector
Pseudocode:
----------
c = (a8*b8, a9*b9, a10*b10, a11*b11, a12*b12, a13*b13, a14*b14, a15*b15)
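The 8-bit products are widened to 16 bits so that they can later be
pairwise accumulated into 32-bit lanes by uadalp without overflowing.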
"""
a_high = tvm.tir.call_intrin("int8x8", "tir.vectorhigh", a)
b_high = tvm.tir.call_intrin("int8x8", "tir.vectorhigh", b)
c = tvm.tir.call_llvm_pure_intrin(
"int16x8", umull_intrin, tvm.tir.const(2, "uint32"), a_high, b_high
)
return c
def umull2(a, b):
"""Multiply long (lower part)
Parameters:
----------
a: int8x16 vector
b: int8x16 vector
Returns:
--------
return an int16x8 vector
Pseudocode:
----------
c = (a0*b0, a1*b1, a2*b2, a3*b3, a4*b4, a5*b5, a6*b6, a7*b7)
"""
a_low = tvm.tir.call_intrin("int8x8", "tir.vectorlow", a)
b_low = tvm.tir.call_intrin("int8x8", "tir.vectorlow", b)
c = tvm.tir.call_llvm_pure_intrin(
"int16x8", umull_intrin, tvm.tir.const(2, "uint32"), a_low, b_low
)
return c
def addp(a, b):
"""Add two vectors in pairs
Parameters:
----------
a: int32x4 vector
b: int32x4 vector
Returns:
--------
return an int32x4 vector
Pseudocode:
----------
c = (a0+a1, a2+a3, b0+b1, b2+b3)
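Example (illustrative values):
a = (1, 2, 3, 4), b = (5, 6, 7, 8)
addp(a, b) = (1+2, 3+4, 5+6, 7+8) = (3, 7, 11, 15)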
"""
return tvm.tir.call_llvm_pure_intrin(
"int32x4", addp_intrin, tvm.tir.const(2, "uint32"), a, b
)
def accumulation_loop(M, N, ins, acc, tile_idx):
"""Internal tile accumulation. This function
takes two arrays of int8 data type A[tile_idx][4][16] and B[tile_idx][4][16], produces
a 4x4 matrix which is equal to A*B' and accumulates into C[4][4]
The pseudo code is as follows.
.. code-block:: c
void accumulation_loop(int8 A[K/16][4][16],
int8 B[K/16][4][16],
int32 C[4][4], int tile_idx){
for (int i = 0; i < 4; i++){
for (int j = 0; j < 4; j++){
for (int k = 0; k < 16; k++){
C[i][j] += A[tile_idx][i][k] * B[tile_idx][j][k];
}
}
}
}
Notes:
* The tiling strategy is picked to maximize register usage.
Parameters:
----------
M : int
Total number of rows of the output matrix
N : int
Total number of columns of the output matrix
ins : list of tvm.tir.buffer
Input buffers
acc : tvm.tir.ir_builder.BufferVar
Bank of register accumulators
tile_idx : int
Index of a sub-tile of A and B in A[tile_idx][:][:] and B[tile_idx][:][:].
Please note that 0 <= tile_idx < K//16
"""
a0 = ins[0].vload([tile_idx, 0, 0], dtype_vec)
a1 = tvm.tir.const(0, "int8x16")
if M > 1:
a1 = ins[0].vload([tile_idx, 1, 0], dtype_vec)
a2 = tvm.tir.const(0, "int8x16")
if M > 2:
a2 = ins[0].vload([tile_idx, 2, 0], dtype_vec)
a3 = tvm.tir.const(0, "int8x16")
if M > 3:
a3 = ins[0].vload([tile_idx, 3, 0], dtype_vec)
b0 = ins[1].vload([tile_idx, 0, 0], dtype_vec)
b1 = tvm.tir.const(0, "int8x16")
if N > 1:
b1 = ins[1].vload([tile_idx, 1, 0], dtype_vec)
b2 = tvm.tir.const(0, "int8x16")
if N > 2:
b2 = ins[1].vload([tile_idx, 2, 0], dtype_vec)
b3 = tvm.tir.const(0, "int8x16")
if N > 3:
b3 = ins[1].vload([tile_idx, 3, 0], dtype_vec)
# First half
# Higher part of a0 * {b0,b1,b2,b3}
d00 = umull(a0, b0)
d01 = umull(a0, b1)
d02 = umull(a0, b2)
d03 = umull(a0, b3)
# Higher part of a1 * {b0,b1,b2,b3}
d10 = umull(a1, b0)
d11 = umull(a1, b1)
d12 = umull(a1, b2)
d13 = umull(a1, b3)
# Accumulate
acc[0] = uadalp(acc[0], d00)
acc[1] = uadalp(acc[1], d01)
acc[2] = uadalp(acc[2], d02)
acc[3] = uadalp(acc[3], d03)
acc[4] = uadalp(acc[4], d10)
acc[5] = uadalp(acc[5], d11)
acc[6] = uadalp(acc[6], d12)
acc[7] = uadalp(acc[7], d13)
# Lower part of a0 * {b0,b1,b2,b3}
d00 = umull2(a0, b0)
d01 = umull2(a0, b1)
d02 = umull2(a0, b2)
d03 = umull2(a0, b3)
# Lower part of a1 * {b0,b1,b2,b3}
d10 = umull2(a1, b0)
d11 = umull2(a1, b1)
d12 = umull2(a1, b2)
d13 = umull2(a1, b3)
# Accumulate again
acc[0] = uadalp(acc[0], d00)
acc[1] = uadalp(acc[1], d01)
acc[2] = uadalp(acc[2], d02)
acc[3] = uadalp(acc[3], d03)
acc[4] = uadalp(acc[4], d10)
acc[5] = uadalp(acc[5], d11)
acc[6] = uadalp(acc[6], d12)
acc[7] = uadalp(acc[7], d13)
# Second half
# Higher part of a2 * {b0,b1,b2,b3}
d00 = umull(a2, b0)
d01 = umull(a2, b1)
d02 = umull(a2, b2)
d03 = umull(a2, b3)
# Higher part of a3 * {b0,b1,b2,b3}
d10 = umull(a3, b0)
d11 = umull(a3, b1)
d12 = umull(a3, b2)
d13 = umull(a3, b3)
# Accumulate
acc[8] = uadalp(acc[8], d00)
acc[9] = uadalp(acc[9], d01)
acc[10] = uadalp(acc[10], d02)
acc[11] = uadalp(acc[11], d03)
acc[12] = uadalp(acc[12], d10)
acc[13] = uadalp(acc[13], d11)
acc[14] = uadalp(acc[14], d12)
acc[15] = uadalp(acc[15], d13)
# Lower part of a2 * {b0,b1,b2,b3}
d00 = umull2(a2, b0)
d01 = umull2(a2, b1)
d02 = umull2(a2, b2)
d03 = umull2(a2, b3)
# Lower part of a3 * {b0,b1,b2,b3}
d10 = umull2(a3, b0)
d11 = umull2(a3, b1)
d12 = umull2(a3, b2)
d13 = umull2(a3, b3)
# Accumulate
acc[8] = uadalp(acc[8], d00)
acc[9] = uadalp(acc[9], d01)
acc[10] = uadalp(acc[10], d02)
acc[11] = uadalp(acc[11], d03)
acc[12] = uadalp(acc[12], d10)
acc[13] = uadalp(acc[13], d11)
acc[14] = uadalp(acc[14], d12)
acc[15] = uadalp(acc[15], d13)
def _intrin_func(ins, outs):
def _instr():
ib = tvm.tir.ir_builder.create()
# Allocate a local buffer (possibly translates to registers)
acc = ib.allocate("int32x4", 16, name="accs", scope="local")
m = outs[0].shape[0]
n = outs[0].shape[1]
# Initialization
for i in range(0, 16):
acc[i] = tvm.tir.const(0, "int32x4")
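# With unroll=True the K//16 accumulation steps are emitted as
# straight-line code; otherwise a runtime loop is generated by the
# IR builder.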
if unroll:
for i in range(0, int(K // 16)):
accumulation_loop(M, N, ins, acc, i)
else:
with ib.for_range(0, K // 16, name="i") as i:
accumulation_loop(M, N, ins, acc, i)
# Final accumulations
# acc[4*r + c] contains the partial accumulations of element C[r][c]
#
# In particular:
# acc[4*r] contains the partial sums of a[r,0:K].*b[0,0:K] -> (a,b,c,d)
# acc[4*r+1] contains the partial sums of a[r, 0:K].*b[1,0:K] -> (e,f,g,h)
# acc[4*r+2] contains the partial sums of a[r, 0:K].*b[2,0:K] -> (i,j,k,l)
# acc[4*r+3] contains the partial sums of a[r, 0:K].*b[3,0:K] -> (m,n,o,p)
#
# Please note that 0 <= r, c < 4
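# For each row r, the three addp calls below reduce its four 4-lane
# accumulators into a single int32x4 holding (C[r][0], C[r][1], C[r][2], C[r][3]).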
acc[0] = addp(acc[0], acc[1]) # (a+b, c+d, e+f, g+h)
acc[1] = addp(acc[2], acc[3]) # (i+j, k+l, m+n, o+p)
acc[0] = addp(acc[0], acc[1]) # (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)
acc[4] = addp(acc[4], acc[5]) # (a+b, c+d, e+f, g+h)
acc[5] = addp(acc[6], acc[7]) # (i+j, k+l, m+n, o+p)
acc[4] = addp(acc[4], acc[5]) # (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)
acc[8] = addp(acc[8], acc[9]) # (a+b, c+d, e+f, g+h)
acc[9] = addp(acc[10], acc[11]) # (i+j, k+l, m+n, o+p)
acc[8] = addp(acc[8], acc[9]) # (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)
acc[12] = addp(acc[12], acc[13]) # (a+b, c+d, e+f, g+h)
acc[13] = addp(acc[14], acc[15]) # (i+j, k+l, m+n, o+p)
acc[12] = addp(acc[12], acc[13]) # (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)
# Store the result
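# When the output tile has fewer than 4 columns, reinterpret the int32x4
# accumulator as a shorter vector so that the vstore below does not write
# past the tile boundary.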
if N > 3:
out_0 = acc[0]
out_1 = acc[4]
out_2 = acc[8]
out_3 = acc[12]
elif N > 2:
out_0 = tvm.tir.call_intrin("int32x3", "tir.reinterpret", acc[0])
out_1 = tvm.tir.call_intrin("int32x3", "tir.reinterpret", acc[4])
out_2 = tvm.tir.call_intrin("int32x3", "tir.reinterpret", acc[8])
out_3 = tvm.tir.call_intrin("int32x3", "tir.reinterpret", acc[12])
elif N > 1:
out_0 = tvm.tir.call_intrin("int32x2", "tir.reinterpret", acc[0])
out_1 = tvm.tir.call_intrin("int32x2", "tir.reinterpret", acc[4])
out_2 = tvm.tir.call_intrin("int32x2", "tir.reinterpret", acc[8])
out_3 = tvm.tir.call_intrin("int32x2", "tir.reinterpret", acc[12])
else:
out_0 = tvm.tir.call_intrin("int32", "tir.reinterpret", acc[0])
out_1 = tvm.tir.call_intrin("int32", "tir.reinterpret", acc[4])
out_2 = tvm.tir.call_intrin("int32", "tir.reinterpret", acc[8])
out_3 = tvm.tir.call_intrin("int32", "tir.reinterpret", acc[12])
ib.emit(outs[0].vstore([0, 0], out_0))
if M > 1:
ib.emit(outs[0].vstore([1, 0], out_1))
if M > 2:
ib.emit(outs[0].vstore([2, 0], out_2))
if M > 3:
ib.emit(outs[0].vstore([3, 0], out_3))
return ib.get()
# Return only the body (no separate reset / update implementations are provided)
return _instr()
buffer_params = {"offset_factor": 1}
return te.decl_tensor_intrin(
C.op,
_intrin_func,
binds={A: a_buffer, B: b_buffer, C: c_buffer},
default_buffer_params=buffer_params,
)