in tensorflow_compression/cc/lib/range_coder.cc [36:263]
void RangeEncoder::Encode(int32_t lower, int32_t upper, int precision,
std::string* sink) {
// Input requirement: 0 <= lower < upper <= 2^precision.
DCHECK_LE(0, lower);
DCHECK_LT(lower, upper);
DCHECK_LE(upper, 1 << precision);
DCHECK_GT(precision, 0);
DCHECK_LE(precision, 16);
// `base` and `size` represent a half-open interval [base, base + size).
// Loop invariant: 2^16 <= size <= 2^32.
//
// Note that keeping size above 2^16 is important. Since the interval sizes
// are quantized to up to 16 bits, the smallest interval size the encode may
// handle is 2^-16. If size is smaller than 2^16, a small interval input may
// collapse the encoder range into an empty interval.
const uint64_t size = static_cast<uint64_t>(size_minus1_) + 1;
DCHECK_NE(size >> 16, 0);
// For short notation, let u := lower and v := upper.
//
// The input u, v represents a half-open interval [u, v) / 2^precision.
// This narrows the current interval roughly to
// [base + (size * u) / 2^precision, base + (size * v) / 2^precision).
//
// TODO(ssjhv): Try rounding if it helps improve compression ratio, at the
// expense of more operations. In the test using Zipf distribution, the
// overhead over the theoretical compression ratio was ~0.01%.
// NOTE: The max value of `size` is 2^32 and size > 0. Therefore `size * u`
// can be rewritten as `(size - 1) * u + u` and all the computation can be
// done in 32-bit mode. If 32-bit multiply is faster, then rewrite.
const uint32_t a = (size * static_cast<uint64_t>(lower)) >> precision;
const uint32_t b = ((size * static_cast<uint64_t>(upper)) >> precision) - 1;
DCHECK_LE(a, b);
// Let's confirm the RHS of a, b fit in uint32 type.
// Recall that 0 <= u < 2^precision, and size <= 2^32. Therefore
// (size * u) / 2^precision < size <= 2^32,
// and the value of a fits in uint32 type. Similarly, since v <= 2^precision,
// (size * v) / 2^precision - 1 <= size - 1 < 2^32.
// For lower bound of b, note that 1 <= v, 2^16 <= size, and 16 <= precision.
// Therefore (size * v) / 2^precision - 1 >= 2^16 / 2^precision - 1 >= 0.
// The new interval is [base + a, base + b] = [base + a, base + b + 1).
base_ += a; // May overflow.
size_minus1_ = b - a;
const bool base_overflow = (base_ < a);
// The encoder has two states. Let's call them state 0 and state 1.
// State 0 is when base < base + size <= 2^32.
// State 1 is when base < 2^32 < base + size.
//
// The encoder initially starts in state 0, with base = 0, size = 2^32.
//
// TODO(ssjhv): Requires some profiling, but the encoder stays in state 0
// most of the time. Should optimize code for state 0.
//
// Each Encode() has up to two places where the interval changes:
// #1. Refine the interval. [base, base + size) -> [base + a, base + b + 1).
// #2. Expand interval if the new size is too small,
// and each change may cause a state transition.
//
// First, consider when the current state is 0.
//
// In this case, the next state after #1 is always state 0, since refining
// interval only shrinks the interval, therefore new_base + new_size <= 2^32.
//
// Let us explain #2.
//
// Recall that at the beginning of each Encode(), the encoder requires
// 2^16 < size <= 2^32. As precision <= 16, the new interval size can be as
// small as 1, but never zero.
//
// To keep size above 2^16, if new size is smaller than or equal to 2^16, the
// encoder would left-shift base and size by 16 bits: size' <- size * 2^16.
// Note that new size' is now in the range [2^16, 2^32].
//
// Since size is left-shifted, the same should be applied to base as well.
// However, after the left-shift, base will then contain 48 bits instead of 32
// bits. Therefore prior to the shift, The upper 16 bits in base should be
// stored somewhere else.
//
// If the upper 16 bits of all values in the interval were the same, i.e., if
// base[32:16] == (base + size - 1)[32:16], then base[32:16] can be written
// out to `output` string, since any further Encode() only narrows down the
// interval and that 16 bits would never change.
//
// If the upper 16 bits were not all the same, since this happens only when
// size <= 2^16, the upper 16 bits may differ only by one, i.e.,
// base[32:16] + 1 == (base + size - 1)[32:16]. At this stage, it is not
// determined yet whether base[32:16] should be written to the output or
// (base[32:16] + 1) should be written to the output. In this case,
// (base[32:16] + 1) is temporarily stored in `delay`, and base is
// left-shifted by 16 bits.
//
// In the latter case, the condition implies that (base // 2^16) and
// ((base + size - 1) // 2^16) were different. Therefore after left-shift by
// 16 bits, the new (base + size) is greater than 2^32, i.e., the encoder
// transition to state 1.
//
// ==== Summary ====
// To detect the current encoder state,
// state 0: delay == 0 iff (base mod 2^32) < (base + size) mod 2^32,
// state 1: delay != 0 iff (base + size) mod 2^32 <= base mod 2^32,
// because size <= 2^32.
//
// ==== Summary for state 0 ====
// 1. Interval refinement does not cause state transition.
// 2. Interval expansion may cause state transition, depending on the upper 16
// bits of base and base + size - 1.
//
// Now suppose the previous state was 1. This means that
// base <= 2^32 < base + size.
//
// When in state 1, an interval refinement may trigger state transition.
// After Encode() refines the interval, there are three possibilities:
// #1. base <= 2^32 < base + size (unchanged),
// #2. 2^32 <= base < base + size (base overflowed),
// #3. base < base + size <= 2^32 (base + size - 1 underflowed).
//
// In case #1, the encoder remains in state 1.
// In case #2 or #3, the encoder state changes to state 0.
//
// ==== State transition for interval refinement ====
// 1. state 0 -> state 0,
// 2. state 1 -> state 0 or state 1.
//
// Therefore if the new state is 1, then the previous state must have been
// state 1.
if (base_ + size_minus1_ < base_) {
// If statement checked if 2^32 < base + size. The new state is 1, hence the
// previous state was also state 1.
DCHECK_NE(((base_ - a) + size) >> 32, 0);
DCHECK_NE(delay_ & 0xFFFF, 0);
// Like in state 0, if the new size is <= 2^16, then base and size should
// be left-shifted by 16 bits. Combine the conditions
// base <= 2^32 < base + size and size <= 2^16 to conclude that
// base[32:16] >= 0xFFFF and (base + size - 1)[32:16] = 0x0000.
//
// Note that 2^32 - base < size, and since base is at least 0xFFFF0000,
// 2^16 - base[16:0] < size. Let base' and size' be the new base and size
// after the bit-shift. Then 2^32 - base' < size' => 2^32 < base' + size'.
// Therefore the encoder remains in state 1.
//
// Lastly, `delay` is modified. Conceptually, delay has to be changed to
// delay' <- delay * 2^16 + (base + size - 1)[32:16].
// Since we know above that (base + size - 1)[32:16] = 0x0000, there is no
// need to explicitly do the computation above, but rather store how many
// trailing zeros there were. For this reason, the lower 16 bits of
// `delay` stores the delayed value when state changed from 0 to 1, and
// delay[32:16] stores the # of trailing zeros (in bytes).
//
// ==== State transition for interval expansion ====
// 1. state 0 -> state 0 or state 1,
// 2. state 1 -> state 1.
if (size_minus1_ >> 16 == 0) {
DCHECK_EQ(base_ >> 16, 0xFFFF);
base_ <<= 16;
size_minus1_ <<= 16;
size_minus1_ |= 0xFFFF;
// TODO(ssjhv): It is possible that for very long input, delay
// overflow during below. If overflow is detected, this delay is too
// long the encoder should forcefully move to state 0. In such case,
// base can be raised to 2^32 (force case #2), or (base + size) can be
// lowered to 2^32 (force case #3), depending on which transition
// keeps size larger.
CHECK_LT(delay_, static_cast<uint64_t>(1) << 62);
delay_ += 0x20000; // Two more bytes of zeros. Check overflow?
}
return;
}
// If reached here, the current state is 0.
// First handle the case when the previous state was state 1.
if (delay_ != 0) {
// In case #2 or #3, the encoder state changes to state 0. Recall that when
// the encoder state changed from state 0 to state 1, the top 16 bits of
// (base + size - 1) was temporarily stored in `delay`, because the output
// could be either (delay - 1) or (delay).
//
// And from above, the delayed value encoded in `delay` is
// delay' <- delay[16:0] * 2^(8 * delay[MAX:16])
//
// In case #2, the interval moved below 2^32. So (delay' - 1) is the
// converged value after interval refinements. Write out
// (delay[16:0] - 1) and write (8 * delay[MAX:16]) bytes of 0xFF.
//
// In case #3, the interval moved above 2^32. So delay' is the converged
// value after interval refinement. Write out delay[16:0] and write
// (8 * delay[MAX:16]) bytes of 0x00.
if (base_overflow) {
// Case #2.
DCHECK_NE((static_cast<uint64_t>(base_ - a) + a) >> 32, 0);
sink->push_back(static_cast<char>(delay_ >> 8));
sink->push_back(static_cast<char>(delay_ >> 0));
sink->append(delay_ >> 16, static_cast<char>(0));
} else {
// Case #3.
DCHECK_EQ(static_cast<uint64_t>(base_ + size_minus1_) >> 32, 0);
--delay_;
sink->push_back(static_cast<char>(delay_ >> 8));
sink->push_back(static_cast<char>(delay_ >> 0));
sink->append(delay_ >> 16, static_cast<char>(0xFF));
}
// Reset to state 0.
delay_ = 0;
}
if (size_minus1_ >> 16 == 0) {
const uint32_t top = base_ >> 16;
base_ <<= 16;
size_minus1_ <<= 16;
size_minus1_ |= 0xFFFF;
if (base_ <= base_ + size_minus1_) {
// Still in state 0. Write the top 16 bits.
sink->push_back(static_cast<char>(top >> 8));
sink->push_back(static_cast<char>(top));
} else {
// New state is 1.
DCHECK_LT(top, 0xFFFF);
delay_ = top + 1;
}
}
}