void Mult()

in src/core/tensor/tensor.cc [1634:1706]

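Computes C = alpha * A * B + beta * C, dispatching to GEMV, GEMM, or batched GEMM depending on the dimensionality of B.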

void Mult(const SType alpha, const Tensor &A, const Tensor &B, const SType beta,
          Tensor *C) {
  Tensor fakeC;
  vector<Block *> read_blocks = {A.block(), B.block()};
  if (beta) {
    // When beta != 0 the old contents of C are read, so keep a copy of the
    // Tensor (sharing C's block) captured by the lambdas below and register
    // C's block as an additional input dependency.
    fakeC = *C;
    read_blocks.push_back(C->block());
  }
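  // B is 1-D: matrix-vector product, dispatched to GEMV.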
  if (B.nDim() == 1u) {
    CHECK_EQ(A.shape().size(), 2u);
    TYPE_LANG_SWITCH(A.data_type(), DType, A.device()->lang(), Lang, {
      auto a = TypeCast<SType, DType>(alpha);
      auto b = TypeCast<SType, DType>(beta);
      Tensor &CRef = *C;
      C->device()->Exec(
          [a, A, b, B, CRef, fakeC](Context *ctx) mutable {
            GEMV<DType, Lang>(a, A, B, b, &CRef, ctx);
          },
          read_blocks, {C->block()}, "GEMV");
    });
  } else if (B.nDim() == 2u) {
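    // B is 2-D: matrix-matrix product, dispatched to GEMM.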
    CHECK_EQ(A.shape().size(), 2u);
    CHECK(!C->transpose());
    TYPE_LANG_SWITCH(A.data_type(), DType, A.device()->lang(), Lang, {
      auto a = TypeCast<SType, DType>(alpha);
      auto b = TypeCast<SType, DType>(beta);
      Tensor &CRef = *C;
      C->device()->Exec(
          [a, A, b, B, CRef, fakeC](Context *ctx) mutable {
            GEMM<DType, Lang>(a, A, B, b, &CRef, ctx);
          },
          read_blocks, {C->block()}, "GEMM");
    });
  } else if (B.nDim() == 3u || B.nDim() == 4u) {
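    // A and B are 3-D or 4-D: batched matrix multiplication (GEMMBatched).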
    CHECK_EQ(A.shape().size(), B.shape().size());
    CHECK(!C->transpose());
    TYPE_LANG_SWITCH(A.data_type(), DType, A.device()->lang(), Lang, {
      auto a = TypeCast<SType, DType>(alpha);
      auto b = TypeCast<SType, DType>(beta);

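      // The batched GEMM kernel expects contiguous inputs, so materialize
      // transposed or broadcast operands into fresh tensors first.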
      Tensor A_tmp;
      Tensor B_tmp;

      if (A.transpose() || A.broadcasted()) {
        A_tmp = Tensor(A.shape(), A.device(), A.data_type());
        singa::Transform(A, &A_tmp);
      } else {
        A_tmp = A;
      }

      if (B.transpose() || B.broadcasted()) {
        B_tmp = Tensor(B.shape(), B.device(), B.data_type());
        singa::Transform(B, &B_tmp);
      } else {
        B_tmp = B;
      }

      // batched GEMM requires A and B to have the same batch dimensions
      CHECK_EQ(A_tmp.shape(0), B_tmp.shape(0));
      if (B.nDim() == 4u) CHECK_EQ(A_tmp.shape(1), B_tmp.shape(1));

      Tensor &CRef = *C;
      C->device()->Exec(
          [a, A_tmp, b, B_tmp, CRef, fakeC](Context *ctx) mutable {
            GEMMBatched<DType, Lang>(a, A_tmp, B_tmp, b, &CRef, ctx);
          },
          read_blocks, {C->block()}, "GEMMBatched");
    });
  } else {
    LOG(FATAL) << "Unsupported tensor dimensions: " << A.nDim() << "d matmul "
               << B.nDim() << "d\n";
  }
}
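
A minimal usage sketch (not part of the source file) showing how this overload might be called on the 2-D GEMM path; it assumes SINGA's public Tensor API from include/singa/core/tensor.h, and the shapes and values are hypothetical:

#include "singa/core/tensor.h"

using singa::Shape;
using singa::Tensor;

void MultExample() {
  // float32 tensors on the default host device (assumed defaults).
  Tensor A(Shape{2, 3});
  Tensor B(Shape{3, 4});
  Tensor C(Shape{2, 4});
  A.SetValue(1.0f);
  B.SetValue(2.0f);
  C.SetValue(0.0f);
  // B is 2-D, so this dispatches to the GEMM branch above; beta == 0 means
  // C's old contents are not read, and its block is not added to read_blocks.
  singa::Mult(1.0f, A, B, 0.0f, &C);
}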