in neuralcompression/entropy_coders/craystack/codecs.py [0:0]
def pop(msg_index, compressed_message, message, cdf_state):
    """Decode the symbol at ``msg_index`` using bits-back coding.

    Three stack operations are performed, in this exact order:

    1. pop latents under the prior ``p(latents)``, decoding into a fresh
       all-zeros buffer of shape ``latent_shape`` at position 0;
    2. build the likelihood codec from those latents and pop the symbol at
       ``msg_index`` under ``p(symbols | latents)``;
    3. build the approximate-posterior codec from the decoded symbol and push
       the latents back under ``q(latents | symbols)``. This is the
       bits-back step.

    Args:
        msg_index: position in ``message`` that the decoded symbol is
            written to.
        compressed_message: compressed state that is popped from and pushed
            to.
        message: buffer holding decoded symbols. The updated buffer returned
            by the likelihood codec's ``pop`` is what this function returns.
        cdf_state: accepted to match the codec ``pop`` call signature but not
            read here; each sub-codec's own ``cdf_state`` is used instead.

    Returns:
        ``(compressed_message, message, None)``: the updated compressed state,
        the updated message buffer, and ``None`` in the cdf-state slot.
    """

    def _pop_from(codec, index, stack, buffer):
        # Each sub-codec pops using its own cdf_state / allow_empty_pops.
        return codec.pop(
            index, stack, buffer, codec.cdf_state, codec.allow_empty_pops
        )

    # Step 1: p(latents) -- decode latents into an empty buffer at index 0.
    compressed_message, latents, _ = _pop_from(
        latent_prior_codec,
        0,
        compressed_message,
        jnp.zeros(latent_shape, dtype=latent_prior_codec.message_dtype),
    )

    # Step 2: p(symbols | latents) -- decode the symbol at msg_index.
    likelihood = obs_codec_maker(latents)
    compressed_message, message, _ = _pop_from(
        likelihood, msg_index, compressed_message, message
    )
    decoded = message[msg_index]

    # Step 3: q(latents | symbols) -- push the latents back (bits-back).
    posterior = latent_posterior_codec_maker(decoded)
    compressed_message, _ = posterior.push(
        latents, compressed_message, posterior.cdf_state
    )

    return compressed_message, message, None