def SamplingWithoutReplacement()

in multiset_codec/codecs.py [0:0]


def SamplingWithoutReplacement() -> Codec:
    '''
    Build a codec that codes symbols onto the ANS state according to the
    empirical distribution of the symbols held in a multiset.

    The multiset is the codec's context, i.e. *context = multiset.

    Encoding first inserts the symbol into the multiset and then pushes it
    onto the ANS state. Decoding pops a symbol from the ANS state and then
    removes it from the multiset. Decoding is therefore a draw without
    replacement from the multiset, and encoding is its exact inverse.

    Both directions return the updated multiset alongside the new ANS state
    so that the caller can thread it through to the next step.
    '''
    def decode(ans_state, multiset):
        # The total must be read *before* the symbol is removed, mirroring
        # encode, which reads it *after* the symbol is inserted.
        total = multiset[0]
        peeked, finish_decode = rans.decode(ans_state, total)
        multiset, interval, symbol = \
                reverse_lookup_then_remove(multiset, peeked[0])
        start, freq = interval
        return finish_decode(start, freq), symbol, multiset

    def encode(ans_state, symbol, multiset):
        # Insert first so that the interval and the total both describe the
        # multiset that decode will see before it removes the symbol.
        multiset, interval = insert_then_forward_lookup(multiset, symbol)
        start, freq = interval
        total = multiset[0]
        return rans.encode(ans_state, start, freq, total), multiset

    def leading_slice(head):
        # View handed to substack: the first element of the head, kept as a
        # length-one slice rather than a scalar.
        return head[:1]

    return substack(Codec(encode, decode), leading_slice)