neuralcompression/entropy_coders/jax_arithemetic_coder.py [163:173]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        branch_idx = lax.argmax(
            jnp.array(
                [
                    (high < ONE_HALF),
                    (low >= ONE_HALF),
                    (low >= ONE_FOURTH) * (high < THREE_FOURTHS),
                ]
            ),
            0,
            jnp.int32,
        )
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



neuralcompression/entropy_coders/jax_arithemetic_coder.py [286:296]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        branch_idx = lax.argmax(
            jnp.array(
                [
                    (high < ONE_HALF),
                    (low >= ONE_HALF),
                    (low >= ONE_FOURTH) * (high < THREE_FOURTHS),
                ]
            ),
            0,
            jnp.int32,
        )
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



