def uniform_quantize()

in diffq/uniform.py [0:0]


def uniform_quantize(p: torch.Tensor, bits: torch.Tensor = torch.tensor(8.)):
    """
    Quantize the given weights over `bits` bits.

    The values of `p` are rescaled linearly from [min(p), max(p)] to [0, 1].
    They are then rounded to one of `2 ** bits` evenly spaced levels.

    Args:
        p: tensor to quantize. Must be non-empty, because `min()` / `max()`
            of an empty tensor raise.
        bits: number of bits, as a tensor. Every entry must lie in [1, 15].
            It is broadcast against `p` when computing the levels.

    Returns:
        - quantized levels, as a `uint8` tensor when every entry of `bits`
          is <= 8, and as an `int16` tensor otherwise.
        - (min, max) range of `p`, as Python floats.

    Raises:
        AssertionError: if any entry of `bits` is outside [1, 15].
    """
    # 15 bits is the most that fits the signed int16 levels used below.
    assert (bits >= 1).all() and (bits <= 15).all()
    num_levels = (2 ** bits.float()).long()
    mn = p.min().item()
    mx = p.max().item()
    span = mx - mn
    if span == 0:
        # Constant input: dividing by a zero span would produce NaN, and
        # casting NaN to an integer dtype is undefined. Map every element
        # to level 0 instead.
        p = torch.zeros_like(p)
    else:
        p = (p - mn) / span  # put p in [0, 1]
    unit = 1 / (num_levels - 1)  # quantization unit
    levels = (p / unit).round()
    # The largest level is 2**bits - 1. That is at most 255 for bits <= 8
    # (fits uint8) and at most 32767 for bits <= 15 (fits int16).
    if (bits <= 8).all():
        levels = levels.byte()
    else:
        levels = levels.short()
    return levels, (mn, mx)