def Multiset()

in multiset_codec/codecs.py [0:0]


def Multiset(symbol_codec: Codec) -> Codec:
    '''
    Build a codec that stores a multiset using bits-back coding.

    Encoding repeatedly draws one symbol out of the multiset with
    SamplingWithoutReplacement (performed as an ANS decode) and then pushes
    that symbol onto the same ANS state with symbol_codec. Decoding undoes
    these two steps in the opposite order, once per element.

    Args:
        symbol_codec: codec for a single symbol. Its encode is called as
            encode(ans_state, symbol) and must return a 1-tuple
            (ans_state,); its decode is called as decode(ans_state) and
            must return (ans_state, symbol).

    Returns:
        A Codec built from the multiset encode/decode pair defined below.
    '''
    sampler = SamplingWithoutReplacement()

    def encode(ans_state, multiset):
        '''
        Encode every element of multiset onto ans_state.

        Returns a 1-tuple holding the resulting ANS state. A multiset that
        is already falsy (empty) leaves the state unchanged.
        '''
        state = ans_state
        remaining = multiset
        while remaining:
            # Step 1: draw a symbol without replacement. The draw is an ANS
            # decode, so it consumes bits from the state and also hands
            # back the multiset with that symbol removed.
            state, symbol, remaining = sampler.decode(state, remaining)

            # Step 2: push the drawn symbol onto the very same ANS state.
            (state,) = symbol_codec.encode(state, symbol)
        return (state,)

    def decode(ans_state, multiset_size):
        '''
        Decode multiset_size symbols from ans_state and rebuild the multiset.

        Returns the pair (ans_state, multiset). The multiset starts out as
        the empty tuple and is grown through the sampler's encode.
        '''
        state = ans_state
        rebuilt = ()
        for _ in range(multiset_size):
            # Undo step 2: the most recently encoded symbol sits on top of
            # the stack, so it is the first one to come off.
            state, symbol = symbol_codec.decode(state)

            # Undo step 1: re-encode the bits that were spent on sampling
            # this symbol. This is the bits-back step; it also inserts the
            # symbol into the multiset being rebuilt.
            state, rebuilt = sampler.encode(state, symbol, rebuilt)
        return state, rebuilt

    return Codec(encode, decode)