def choose_branch_and_execute()

in neuralcompression/entropy_coders/jax_arithemetic_coder.py [0:0]


    def choose_branch_and_execute(vals):
        """Run one renormalization step of the coder state.

        Data-dependent Python ``if``/``elif`` chains cannot be traced by JAX,
        so the three cases are encoded as a boolean vector. ``lax.argmax``
        selects the index of the first true entry, and ``lax.switch`` then
        dispatches to the matching entry of ``branches``.

        Args:
            vals: 7-tuple ``(high, low, pending_bits, int_array, byte, idx,
                bit_idx)`` of coder state.

        Returns:
            The updated 7-tuple in the same order. After the chosen branch
            runs, ``high`` is shifted left by one bit with a 1 fed into its
            lowest bit, ``low`` is shifted left by one bit, and both are
            masked with ``MAX_CODE``.
        """
        high, low, pending_bits, int_array, byte, idx, bit_idx = vals

        # Conditions in priority order. argmax over a boolean vector yields
        # the index of the FIRST true entry, emulating an if/elif chain.
        # NOTE(review): if none of the conditions is true, argmax returns 0
        # and branch 0 still runs -- presumably the caller only invokes this
        # while at least one condition holds; confirm against the loop test.
        conditions = jnp.array(
            [
                high < ONE_HALF,
                low >= ONE_HALF,
                (low >= ONE_FOURTH) & (high < THREE_FOURTHS),
            ]
        )
        branch_idx = lax.argmax(conditions, 0, jnp.int32)

        # Execute the selected branch. lax.switch requires every branch to
        # return the same structure; here it is unpacked as the 7-tuple state.
        high, low, pending_bits, int_array, byte, idx, bit_idx = lax.switch(
            branch_idx,
            branches,
            (high, low, pending_bits, int_array, byte, idx, bit_idx),
        )

        # After-shifts: shift both bounds left by one bit (high also gets a
        # trailing 1 bit), then mask both with MAX_CODE.
        high = ((high << 1) + 1) & MAX_CODE
        low = (low << 1) & MAX_CODE

        return high, low, pending_bits, int_array, byte, idx, bit_idx