in src/core/tensor/tensor.cc [1634:1706]
void Mult(const SType alpha, const Tensor &A, const Tensor &B, const SType beta,
          Tensor *C) {
  // Matrix-multiplication dispatcher; the kernel is chosen by the rank of B:
  //   - 1D B:    GEMV        (A must be 2D),
  //   - 2D B:    GEMM        (A must be 2D, C must not be transposed),
  //   - 3D/4D B: GEMMBatched (A must have the same rank as B, matching leading
  //              batch dim(s), C must not be transposed).
  // Any other rank of B is a fatal error.
  // alpha/beta are cast to A's element type and forwarded to the kernel.
  // NOTE(review): presumably the kernels follow the BLAS convention
  // C = alpha * A * B + beta * C -- they are defined elsewhere; confirm there.
  //
  // When beta is non-zero the current contents of C are an input, so C's block
  // is declared as a read dependency and an extra handle to C (fakeC) is
  // captured by the kernel lambda.
  Tensor fakeC;
  vector<Block *> read_blocks = {A.block(), B.block()};
  if (beta) {
    fakeC = *C;
    read_blocks.push_back(C->block());
  }
  if (B.nDim() == 1u) {
    CHECK_EQ(A.shape().size(), 2u);
    TYPE_LANG_SWITCH(A.data_type(), DType, A.device()->lang(), Lang, {
      auto a = TypeCast<SType, DType>(alpha);
      auto b = TypeCast<SType, DType>(beta);
      Tensor &CRef = *C;
      C->device()->Exec(
          [a, A, b, B, CRef, fakeC](Context *ctx) mutable {
            GEMV<DType, Lang>(a, A, B, b, &CRef, ctx);
          },
          read_blocks, {C->block()}, "GEMV");
    });
  } else if (B.nDim() == 2u) {
    CHECK_EQ(A.shape().size(), 2u);
    CHECK(!C->transpose());
    TYPE_LANG_SWITCH(A.data_type(), DType, A.device()->lang(), Lang, {
      auto a = TypeCast<SType, DType>(alpha);
      auto b = TypeCast<SType, DType>(beta);
      Tensor &CRef = *C;
      C->device()->Exec(
          [a, A, b, B, CRef, fakeC](Context *ctx) mutable {
            GEMM<DType, Lang>(a, A, B, b, &CRef, ctx);
          },
          read_blocks, {C->block()}, "GEMM");
    });
  } else if (B.nDim() == 3u || B.nDim() == 4u) {
    CHECK_EQ(A.shape().size(), B.shape().size());
    CHECK(!C->transpose());
    TYPE_LANG_SWITCH(A.data_type(), DType, A.device()->lang(), Lang, {
      auto a = TypeCast<SType, DType>(alpha);
      auto b = TypeCast<SType, DType>(beta);
      // The batched kernel is given A_tmp/B_tmp: either the original tensors,
      // or, when transposed/broadcasted, fresh tensors filled by Transform.
      Tensor A_tmp;
      Tensor B_tmp;
      if (A.transpose() || A.broadcasted()) {
        A_tmp = Tensor(A.shape(), A.device(), A.data_type());
        singa::Transform(A, &A_tmp);
      } else {
        A_tmp = A;
      }
      if (B.transpose() || B.broadcasted()) {
        B_tmp = Tensor(B.shape(), B.device(), B.data_type());
        singa::Transform(B, &B_tmp);
      } else {
        B_tmp = B;
      }
      // batch GEMM should have same batch size
      CHECK_EQ(A_tmp.shape(0), B_tmp.shape(0));
      if (B.nDim() == 4u) CHECK_EQ(A_tmp.shape(1), B_tmp.shape(1));
      // Declare the blocks the kernel actually reads. If A/B were transformed,
      // A_tmp/B_tmp own new blocks written by the Transform calls above. Listing
      // A.block()/B.block() instead would hide that producer->consumer
      // dependency from whatever tracking Device::Exec does with these lists.
      vector<Block *> batched_read_blocks = {A_tmp.block(), B_tmp.block()};
      if (beta) batched_read_blocks.push_back(C->block());
      Tensor &CRef = *C;
      C->device()->Exec(
          [a, A_tmp, b, B_tmp, CRef, fakeC](Context *ctx) mutable {
            GEMMBatched<DType, Lang>(a, A_tmp, B_tmp, b, &CRef, ctx);
          },
          batched_read_blocks, {C->block()}, "GEMMBatched");
    });
  } else {
    LOG(FATAL) << "Un-supported tensor dimentions " << A.nDim() << "d matmul "
               << B.nDim() << "d\n";
  }
}