std::shared_ptr LocalNorm::run()

in recipes/streaming_convnets/inference/inference/module/nn/LocalNorm.cpp [50:106]


// Normalizes every complete feature frame in the input's buffer(0) using the
// mean and standard deviation of a sliding window of recent frames.
//
// Per-frame sums and sums of squares are appended to output buffers 1 and 2,
// which persist across calls and form the window. The statistics used for a
// frame cover that frame plus up to leftContextSize_ preceding frames. Only
// whole frames (featureSize_ floats each) are consumed; a trailing partial
// frame stays in the input buffer for a later call.
//
// @param input  processing state whose buffer(0) holds incoming features.
//               Must be non-null and have a non-null next() state.
// @return input->next(); its buffer(0) receives the normalized frames.
std::shared_ptr<ModuleProcessingState> LocalNorm::run(
    std::shared_ptr<ModuleProcessingState> input) {
  assert(input);
  std::shared_ptr<ModuleProcessingState> output = input->next();
  assert(output);
  std::shared_ptr<IOBuffer> inputBuf = input->buffer(0);
  assert(inputBuf);
  // Integer division: a trailing partial frame is not processed this call.
  const int nFeatFrames = inputBuf->size<float>() / featureSize_;
  if (nFeatFrames == 0) {
    return output;
  }

  // Output layout: buffer(0) normalized features, buffer(1) per-frame sums,
  // buffer(2) per-frame sums of squares (the sliding-window state).
  assert(output->buffers().size() >= 3);
  std::shared_ptr<IOBuffer> outputBuf = output->buffer(0);
  assert(outputBuf);
  std::shared_ptr<IOBuffer> sumBuf = output->buffer(1);
  std::shared_ptr<IOBuffer> sqSumBuf = output->buffer(2);
  assert(sumBuf && sqSumBuf);

  const int outputSize = nFeatFrames * featureSize_;
  outputBuf->ensure<float>(outputSize);

  for (int t = 0; t < nFeatFrames; ++t) {
    const float* inPtr = inputBuf->data<float>();
    float* outPtr = outputBuf->tail<float>();

    // Accumulate in double (0.0 initial value), then store as float.
    float curSum =
        static_cast<float>(std::accumulate(inPtr, inPtr + featureSize_, 0.0));
    float curSqSum = static_cast<float>(
        std::inner_product(inPtr, inPtr + featureSize_, inPtr, 0.0));
    sumBuf->write(&curSum, 1);
    sqSumBuf->write(&curSqSum, 1);

    // Window statistics over every frame currently held in the state buffers
    // (including the one just appended).
    int totalSize = sumBuf->size<float>();
    assert(totalSize == sqSumBuf->size<float>());
    const float* sums = sumBuf->data<float>();
    const float* sqSums = sqSumBuf->data<float>();
    const float totalSum =
        static_cast<float>(std::accumulate(sums, sums + totalSize, 0.0));
    const float totalSqSum =
        static_cast<float>(std::accumulate(sqSums, sqSums + totalSize, 0.0));

    const float windowElems = static_cast<float>(totalSize * featureSize_);
    const float mean = totalSum / windowElems;
    // E[x^2] - E[x]^2 can round to a slightly negative value when the window
    // is (nearly) constant. sqrt() of a negative number is NaN, and NaN would
    // slip past the kEpsilon guard below, so clamp the variance at zero.
    float variance = totalSqSum / windowElems - mean * mean;
    if (variance < 0.0f) {
      variance = 0.0f;
    }
    float stddev = std::sqrt(variance);

    // Guard against a (near-)zero spread; presumably meanNormalize divides by
    // stddev — confirm against its definition.
    if (stddev <= kEpsilon) {
      stddev = 1.0;
    }
    // NOTE(review): the 1.0 / 0.0 arguments look like scale / offset —
    // confirm against meanNormalize's signature.
    meanNormalize(inPtr, featureSize_, mean, stddev, 1.0, 0.0, outPtr);

    // Keep at most leftContextSize_ frames of history for the next frame.
    if (totalSize > leftContextSize_) {
      sumBuf->consume<float>(1);
      sqSumBuf->consume<float>(1);
    }
    inputBuf->consume<float>(featureSize_);
    outputBuf->move<float>(featureSize_);
  }
  return output;
}