in diffq/uniform.py [0:0]
def uniform_quantize(p: torch.Tensor, bits: torch.Tensor = torch.tensor(8.)):
    """
    Quantize the given weights uniformly over `bits` bits.

    The values of `p` are rescaled affinely so that `p.min()` maps to 0 and
    `p.max()` maps to `2 ** bits - 1`, then rounded to the nearest level.

    Args:
        p: tensor of values to quantize. Must be non-empty, since
            `p.min()` / `p.max()` are taken over the whole tensor.
        bits: tensor holding the number of bits; every entry must be in
            [1, 15]. NOTE(review): presumably a scalar tensor or one that
            broadcasts against `p` -- confirm against callers.

    Returns:
        - quantized levels: `torch.uint8` when all `bits` are <= 8,
          `torch.int16` otherwise.
        - (min, max) range of the original values, as Python numbers.

    Raises:
        AssertionError: if any entry of `bits` is outside [1, 15].
    """
    assert (bits >= 1).all() and (bits <= 15).all()
    num_levels = (2 ** bits.float()).long()
    mn = p.min().item()
    mx = p.max().item()
    span = mx - mn
    if span == 0:
        # All values are identical: dividing by the zero range would give
        # NaN, and casting NaN to an integer dtype is undefined. Every value
        # sits at the bottom of the (degenerate) range, so emit level 0; the
        # returned (mn, mx) pair alone is enough to recover the constant.
        p = p - mn
    else:
        p = (p - mn) / span  # put p in [0, 1]
    unit = 1 / (num_levels - 1)  # quantization unit
    levels = (p / unit).round()
    # Up to 8 bits the largest level is 255 and fits in uint8; up to 15 bits
    # the largest level is 32767 and fits in a signed int16.
    if (bits <= 8).all():
        levels = levels.byte()
    else:
        levels = levels.short()
    return levels, (mn, mx)