std::shared_ptr<ModuleProcessingState> LinearFbGemm::run()

in recipes/streaming_convnets/inference/inference/module/nn/backend/fbgemm/LinearFbGemm.cpp [66:103]


std::shared_ptr<ModuleProcessingState> LinearFbGemm::run(
    std::shared_ptr<ModuleProcessingState> input) {
  assert(input);
  std::shared_ptr<ModuleProcessingState> output = input->next();
  assert(output);
  assert(input->buffers().size() == 1);
  std::shared_ptr<IOBuffer> inputBuf = input->buffer(0);
  assert(inputBuf);

  int nFrames = inputBuf->size<float>() / nInput_;
  if (nFrames == 0) {
    return output;
  }
  assert(output->buffers().size() == 1);
  std::shared_ptr<IOBuffer> outputBuf = output->buffer(0);
  assert(outputBuf);

  const int outSize = nFrames * nOutput_;
  outputBuf->ensure<float>(outSize);
  auto* outPtr = outputBuf->tail<float>();
  for (int i = 0; i < nFrames; ++i) {
    std::copy_n(bias_->buffer_.data<float>(), nOutput_, outPtr + i * nOutput_);
  }

  outputBuf->move<float>(outSize);

  constexpr float beta = 1.0;
  cblas_gemm_compute(
      fbgemm::matrix_op_t::Transpose,
      nFrames,
      inputBuf->data<float>(),
      *packedWeights_,
      beta,
      outPtr);

  inputBuf->consume<float>(nFrames * nInput_);
  return output;
}