in multiset_codec/codecs.py [0:0]
def Multiset(symbol_codec: Codec) -> Codec:
    '''
    Build a codec that compresses a multiset with bits-back coding.

    Encoding repeatedly draws one symbol out of the multiset via
    SamplingWithoutReplacement (an ANS *decode*, which consumes bits from
    the state) and then pushes that symbol with symbol_codec (an ANS
    *encode*). Decoding undoes those two steps in the opposite order.

    The returned codec's encode takes (ans_state, multiset) and returns the
    1-tuple (ans_state,); its decode takes (ans_state, multiset_size) and
    returns (ans_state, multiset).
    '''
    sampler = SamplingWithoutReplacement()

    def encode(ans_state, multiset):
        state, remaining = ans_state, multiset
        # Keep going until every symbol has been drawn out of the multiset.
        while remaining:
            # Step 1: draw a symbol without replacement. The draw is an ANS
            # decode, and the sampler hands back the multiset with that
            # symbol removed.
            state, picked, remaining = sampler.decode(state, remaining)
            # Step 2: push the drawn symbol onto the very same ANS state.
            (state,) = symbol_codec.encode(state, picked)
        return (state,)

    def decode(ans_state, multiset_size):
        # Start from the empty multiset and grow it one symbol at a time.
        state, rebuilt = ans_state, ()
        for _ in range(multiset_size):
            # Undo step 2: pop the symbol sitting on top of the ANS stack.
            state, picked = symbol_codec.decode(state)
            # Undo step 1: re-encode the bits that were spent sampling this
            # symbol. This is where the "bits back" are recovered.
            state, rebuilt = sampler.encode(state, picked, rebuilt)
        return state, rebuilt

    return Codec(encode, decode)