in neuralcompression/entropy_coders/jax_arithemetic_coder.py [0:0]
def choose_branch_and_execute(vals):
    """Run one interval-renormalization step of the encoder without Python branching.

    Python ``if``/``elif`` on traced values does not work under JAX tracing, so
    the step works in two stages. First, the applicable case is selected by
    ``lax.argmax`` over a vector of boolean conditions. Then it is dispatched
    through ``lax.switch`` into the module-level ``branches`` sequence.
    ``lax.argmax`` returns the index of the first maximal element, which is the
    first condition that is ``True``. This reproduces ``if``/``elif``
    precedence.

    Args:
        vals: 7-tuple ``(high, low, pending_bits, int_array, byte, idx,
            bit_idx)`` holding the coder state.

    Returns:
        The updated 7-tuple, in the same order. The chosen branch has been
        applied, and ``high``/``low`` have been shifted left by one bit and
        masked with ``MAX_CODE``.
    """
    high, low, pending_bits, int_array, byte, idx, bit_idx = vals

    # Conditions in if/elif priority order. Which action each index performs
    # is defined by ``branches`` (not visible here). Presumably:
    # 0 = emit a 0 bit, 1 = emit a 1 bit, 2 = middle-straddle / pending-bit
    # case. TODO confirm against ``branches``.
    conditions = jnp.array(
        [
            high < ONE_HALF,
            low >= ONE_HALF,
            (low >= ONE_FOURTH) & (high < THREE_FOURTHS),
        ]
    )
    # NOTE(review): argmax of an all-False vector is 0, so branch 0 would run
    # if no condition held. Presumably the caller only invokes this step while
    # one of them is true. TODO confirm.
    branch_idx = lax.argmax(conditions, 0, jnp.int32)

    high, low, pending_bits, int_array, byte, idx, bit_idx = lax.switch(
        branch_idx,
        branches,
        (high, low, pending_bits, int_array, byte, idx, bit_idx),
    )

    # After-shifts: both bounds move left one bit and are clipped to the code
    # width. ``high`` shifts in a 1 bit while ``low`` shifts in a 0 bit.
    high = ((high << 1) + 1) & MAX_CODE
    low = (low << 1) & MAX_CODE
    return high, low, pending_bits, int_array, byte, idx, bit_idx