in recipes/local_prior_match/src/runtime/Utils.cpp [148:181]
/**
 * Transforms grouped hypothesis log-probabilities into renormalized and/or
 * linear-domain scores.
 *
 * `logprob` is split along dimension 0 into consecutive segments whose
 * lengths are the positive entries of `hypoNums` (entries <= 0 contribute no
 * segment). Each segment is transformed independently:
 *   - renormalize && linear  : fl::softmax over dimension 0
 *   - renormalize && !linear : fl::logSoftmax over dimension 0
 *   - !renormalize && linear : elementwise fl::exp
 * and the transformed segments are concatenated back along dimension 0.
 *
 * @param logprob     log-probabilities, one entry per hypothesis along
 *                    dimension 0 (presumably a column vector — TODO confirm
 *                    with callers).
 * @param hypoNums    number of hypotheses in each consecutive group.
 * @param renormalize normalize each group independently (softmax /
 *                    log-softmax) instead of leaving scores as-is.
 * @param linear      produce probabilities instead of log-probabilities.
 * @return `logprob` unchanged when both flags are false; otherwise the
 *         concatenation of the transformed segments.
 * @throws std::runtime_error if the positive entries of `hypoNums` do not sum
 *         to `logprob.dims()[0]` (only checked when a transform is requested).
 */
fl::Variable adjustProb(
    const fl::Variable& logprob,
    const std::vector<int>& hypoNums,
    bool renormalize,
    bool linear) {
  if (!renormalize && !linear) {
    return logprob;
  }

  // Validate the segmentation BEFORE slicing: if the group sizes overshoot
  // the tensor length, the slice below would otherwise fail with an opaque
  // out-of-range indexing error instead of this descriptive message.
  // Accumulate in a wide type so a large sum cannot overflow `int`.
  long long total = 0;
  for (int hypoNum : hypoNums) {
    if (hypoNum > 0) {
      total += hypoNum;
    }
  }
  if (total != logprob.dims()[0]) {
    throw std::runtime_error(
        "Total number of hypos inconsistent : " + std::to_string(total) +
        " vs " + std::to_string(logprob.dims()[0]));
  }

  std::vector<fl::Variable> outputVec;
  outputVec.reserve(hypoNums.size());
  int offset = 0;
  for (int hypoNum : hypoNums) {
    if (hypoNum <= 0) {
      continue; // empty (or invalid) groups contribute no segment
    }
    // af::seq bounds are inclusive, hence the -1.
    auto logprobSlice = logprob(af::seq(offset, offset + hypoNum - 1));
    if (!renormalize) {
      // Both-false returned early above, so `linear` must be set here.
      outputVec.emplace_back(fl::exp(logprobSlice));
    } else if (linear) {
      outputVec.emplace_back(fl::softmax(logprobSlice, 0));
    } else {
      outputVec.emplace_back(fl::logSoftmax(logprobSlice, 0));
    }
    offset += hypoNum;
  }
  return concatenate(outputVec, 0);
}