fl::Variable adjustProb()

in recipes/local_prior_match/src/runtime/Utils.cpp [148:181]


/**
 * Optionally renormalizes and/or converts to linear scale a vector of
 * log-probabilities that is partitioned into consecutive groups of hypotheses.
 *
 * @param logprob     log-probabilities indexed along dimension 0; the groups
 *                    described by `hypoNums` are laid out contiguously, in
 *                    order, along that dimension.
 * @param hypoNums    number of entries in each group; non-positive entries are
 *                    skipped and consume no rows of `logprob`.
 * @param renormalize if true, each group is normalized independently
 *                    (softmax / log-softmax over dimension 0 of the group).
 * @param linear      if true, probabilities are returned instead of
 *                    log-probabilities.
 * @return `logprob` itself when both flags are false; otherwise the adjusted
 *         groups concatenated along dimension 0 in their original order.
 * @throws std::runtime_error if the positive entries of `hypoNums` do not sum
 *         to `logprob.dims()[0]`.
 */
fl::Variable adjustProb(
    const fl::Variable& logprob,
    const std::vector<int>& hypoNums,
    bool renormalize,
    bool linear) {
  if (!renormalize && !linear) {
    return logprob;
  }

  // Validate the partition BEFORE slicing, so that an inconsistent hypoNums is
  // reported with a clear message rather than as an out-of-range indexing
  // error from ArrayFire, and no per-group work is wasted on invalid input.
  // Accumulate in dim_t (the type of dims()[0]) to avoid int overflow.
  dim_t totalHypos = 0;
  for (const int hypoNum : hypoNums) {
    if (hypoNum > 0) {
      totalHypos += hypoNum;
    }
  }
  if (totalHypos != logprob.dims()[0]) {
    throw std::runtime_error(
        "Total number of hypos inconsistent : " + std::to_string(totalHypos) +
        " vs " + std::to_string(logprob.dims()[0]));
  }

  std::vector<fl::Variable> outputVec;
  outputVec.reserve(hypoNums.size());
  dim_t offset = 0;
  for (const int hypoNum : hypoNums) {
    if (hypoNum <= 0) {
      continue;
    }
    auto logprobSlice = logprob(af::seq(offset, offset + hypoNum - 1));
    if (renormalize) {
      // Normalize within the group; `linear` selects probs vs log-probs.
      outputVec.emplace_back(
          linear ? fl::softmax(logprobSlice, 0)
                 : fl::logSoftmax(logprobSlice, 0));
    } else {
      // Only `linear` can be set here: the both-false case returned early.
      outputVec.emplace_back(fl::exp(logprobSlice));
    }
    offset += hypoNum;
  }
  return fl::concatenate(outputVec, 0);
}