def decode()

in multiset_codec/rans.py [0:0]


def decode(ans_state, precisions):
    """Start decoding one symbol per lane from an rANS state.

    Decoding is split into two steps. This function computes the value
    ``cfs = head % precisions`` from the current head. The caller then
    resolves which symbol that value falls in, and passes that symbol's
    start and frequency to the returned ``pop`` closure. ``pop`` produces
    the previous (pre-encode) state.

    The precision is treated as a general integer modulus: ``%`` and ``//``
    are used rather than masks and shifts.

    Args:
        ans_state: A ``(head, tail)`` pair. ``head`` is the per-lane
            numeric state. ``tail`` is whatever ``stack_slice`` consumes
            when lanes need refilling.
        precisions: Modulus per lane (scalar or array). It is passed
            through ``atleast_1d``.

    Returns:
        ``(cfs, pop)``:
        - ``cfs`` is ``head % precisions``.
        - ``pop(starts, freqs)`` returns the new ``(head, tail)`` state.
    """
    precisions = atleast_1d(precisions)
    old_head, old_tail = ans_state

    # Remainder of the head modulo the precision. It locates the encoded
    # symbol within the cumulative-frequency range [0, precision).
    cfs = old_head % precisions

    def pop(starts, freqs):
        """Undo one encode step given the decoded symbol's start and freq.

        Returns a new ``(head, tail)`` pair. The captured ``old_head`` is
        not modified; the refill below writes into a freshly computed array.
        """
        starts = atleast_1d(starts)
        freqs = atleast_1d(freqs)

        # Inverse of the encoder step, with M = precision and c = start:
        #   s = f * (s' // M) + (s' % M) - c
        new_head = freqs * (old_head // precisions) + cfs - starts

        # Lanes whose head fell below the lower bound must be refilled.
        underflow = new_head < rans_l
        count = np.sum(underflow)

        if count <= 0:
            # Nothing to refill: the tail is returned untouched.
            return new_head, old_tail

        # Pull `count` words off the tail, one per underflowing lane.
        # Each word is shifted into the low 32 bits of its lane's head.
        # NOTE(review): assumes stack_slice returns (remaining_tail, words)
        # with one 32-bit word per lane, as the original comments state.
        remaining_tail, words = stack_slice(old_tail, count)
        new_head[underflow] = (new_head[underflow] << 32) | words
        return new_head, remaining_tail

    return cfs, pop