def ByteArray()

in multiset_codec/codecs.py [0:0]


def ByteArray(max_array_size: int) -> Codec:
    '''
    Encodes and decodes an array of bytes onto the ANS state.

    First, the bytearray size is encoded using a uniform distribution in
    the interval [0, max_array_size). Then, the bytes are encoded in parallel
    using a uniform distribution in the interval [0, 256).

    Because the size is drawn from [0, max_array_size), the longest array
    this codec can represent has ``max_array_size - 1`` bytes.

    Args:
        max_array_size: exclusive upper bound on the length of the byte
            arrays passed to ``encode``.

    Returns:
        A ``Codec`` whose ``encode(ans_state, bytes_array)`` returns the
        1-tuple ``(ans_state,)`` and whose ``decode(ans_state)`` returns
        ``(ans_state, bytes)``.

    Raises:
        ValueError: from ``encode``, if ``len(bytes_array) >= max_array_size``
            (the size could not be represented by the size codec).
    '''

    # The size is a single symbol, so it only occupies the first head slot.
    size_codec = substack(Uniform(max_array_size), lambda h: h[:1])

    def bytes_codec(size):
        # One uniform byte symbol per head slot; the view is sized to the
        # array length. NOTE(review): assumes the ANS head has at least
        # `size` entries, since h[:size] silently truncates — TODO confirm.
        return substack(Uniform(256), lambda h: h[:size])

    def encode(ans_state, bytes_array):
        bytes_ndarray = np.frombuffer(bytes_array, dtype=np.uint8)
        size = len(bytes_array)
        if size >= max_array_size:
            raise ValueError(
                f"byte array of length {size} cannot be encoded: "
                f"length must be < max_array_size ({max_array_size})"
            )
        (ans_state,) = bytes_codec(size).encode(ans_state, bytes_ndarray)
        # The size is pushed last so that decode can pop it first and then
        # knows how many bytes to pop.
        (ans_state,) = size_codec.encode(ans_state, size)
        return (ans_state,)

    def decode(ans_state):
        ans_state, size = size_codec.decode(ans_state)
        # size_codec decodes a length-1 view, so take its single element.
        ans_state, bytes_ndarray = bytes_codec(size[0]).decode(ans_state)
        bytes_array = bytes_ndarray.astype(np.uint8).tobytes()
        return ans_state, bytes_array

    return Codec(encode, decode)