Alignment transferThroughCharacters()

in inference/src/translator/response.cpp [13:98]


/// Re-expresses an alignment over target-side pivot tokens as an alignment over source-side pivot tokens, by routing
/// probability mass through the characters that the two pivot tokenizations share.
///
/// The two pivot sequences are walked in lockstep with two cursors (`sq` over sourceSidePivots, `qt` over
/// targetSidePivots). For every row t, the mass pivotGivenTargets[t][qt] is handled as follows:
///   - an exact range match moves it wholesale to column sq;
///   - otherwise column sq receives the fraction proportional to the number of bytes the two ranges share.
///
/// @param sourceSidePivots  ranges of one tokenization of the pivot text. Assumed sorted and non-overlapping —
///                          NOTE(review): confirm against callers.
/// @param targetSidePivots  ranges of a second tokenization of the same pivot text (same assumption).
/// @param pivotGivenTargets matrix indexed [row t][target-side pivot qt]. Every row must have at least
///                          targetSidePivots.size() entries.
/// @return matrix indexed [row t][source-side pivot sq] with shape
///         pivotGivenTargets.size() x sourceSidePivots.size().
Alignment transferThroughCharacters(const std::vector<ByteRange> &sourceSidePivots,
                                    const std::vector<ByteRange> &targetSidePivots,
                                    const Alignment &pivotGivenTargets) {
  // Start from an all-zero matrix: one row per row of pivotGivenTargets, one column per source-side pivot.
  Alignment remapped(pivotGivenTargets.size(), std::vector<float>(sourceSidePivots.size(), 0.0f));

  size_t sq = 0;  // cursor into sourceSidePivots
  size_t qt = 0;  // cursor into targetSidePivots

  // Every iteration advances sq, qt or both, so the loop terminates.
  while (sq < sourceSidePivots.size() && qt < targetSidePivots.size()) {
    const ByteRange &sourceSidePivot = sourceSidePivots[sq];
    const ByteRange &targetSidePivot = targetSidePivots[qt];

    if (sourceSidePivot.begin == targetSidePivot.begin && sourceSidePivot.end == targetSidePivot.end) {
      // Perfect match: all of qt's mass goes to sq; move both cursors.
      for (size_t t = 0; t < pivotGivenTargets.size(); t++) {
        remapped[t][sq] += pivotGivenTargets[t][qt];
      }
      sq++;
      qt++;
      continue;
    }

    // Partial overlap: sq receives the share of qt's mass proportional to the bytes they have in common.
    size_t left = std::max(targetSidePivot.begin, sourceSidePivot.begin);
    size_t right = std::min(targetSidePivot.end, sourceSidePivot.end);

    assert(left < right);  // there should be overlap.

    // When the assert above is compiled out (NDEBUG), two things could go wrong without this guard:
    //   - `right - left` would wrap around (size_t) if the ranges do not overlap;
    //   - a zero-width target pivot would make the division below 0 / 0 (NaN).
    // Either case would corrupt the whole row. Assuming ByteRange::size() == end - begin, left < right also
    // guarantees a non-zero divisor.
    if (left < right) {
      const size_t charCount = right - left;
      const float probSpread = static_cast<float>(targetSidePivot.size());
      for (size_t t = 0; t < pivotGivenTargets.size(); t++) {
        remapped[t][sq] += charCount * pivotGivenTargets[t][qt] / probSpread;
      }
    }

    // Advance whichever range ends first; if both end at the same byte, advance both.
    if (sourceSidePivot.end == targetSidePivot.end) {
      sq++;
      qt++;
    } else if (sourceSidePivot.end > targetSidePivot.end) {
      qt++;
    } else {  // sourceSidePivot.end < targetSidePivot.end
      sq++;
    }
  }

  // Kept for future debugging: every source-side pivot is expected to have been consumed by the walk above. This guards
  // against EOS being absent in unusual multi-model pipelines.
  assert(sq == sourceSidePivots.size());

  // Target-side pivots can be left over when EOS was not predicted, because the two-pointer walk cannot place them.
  // Their mass is spread uniformly over all source-side pivots.
  while (qt < targetSidePivots.size()) {
    // In DEBUG, only a zero-width EOS at the very end is expected here.
    assert(qt == targetSidePivots.size() - 1 && targetSidePivots[qt].size() == 0);
    for (size_t t = 0; t < pivotGivenTargets.size(); t++) {
      float gift = pivotGivenTargets[t][qt] / sourceSidePivots.size();
      for (size_t s = 0; s < sourceSidePivots.size(); s++) {
        remapped[t][s] += gift;
      }
    }
    qt++;
  }

#ifdef DEBUG
  // Sanity check: for each row, all probability mass held by the target-side pivots must have been transferred to the
  // remapped representation.
  //
  // Floating-point arithmetic upstream of the Alignment matrix can leave rows that do not sum exactly to 1. The
  // comparison is therefore against each row's actual sum, not against 1.
  const float EPS = 1e-6;
  for (size_t t = 0; t < pivotGivenTargets.size(); t++) {
    float sum = 0.0f;
    float expectedSum = 0.0f;
    for (size_t q = 0; q < targetSidePivots.size(); q++) {
      expectedSum += pivotGivenTargets[t][q];
    }
    for (size_t s = 0; s < sourceSidePivots.size(); s++) {
      sum += remapped[t][s];
    }
    // std::cerr is unit-buffered, so an explicit flush via std::endl is unnecessary.
    std::cerr << fmt::format("Sum @ token {} = {} to be compared with expected {}.", t, sum, expectedSum) << '\n';
    ABORT_IF(std::abs(sum - expectedSum) > EPS, "Haven't accumulated probabilities, re-examine");
  }
#endif  // DEBUG

  return remapped;
}