in diffq/diffq.py [0:0]
def _register_param(self, name, param, module, other):
if other is not None:
return self.__class__._QuantizedParam(
name=name, param=param, module=module, logit=other.logit, other=other)
assert self.group_size == 0 or param.numel() % self.group_size == 0
# we want the initial number of bits to be init_bits.
if self.param == "noise":
noise_scale = 1 / (2 ** self.init_bits - 1)
t = (math.log(noise_scale) - math.log(self._max_noise)) / (
math.log(self._min_noise) - math.log(self._max_noise))
else:
t = (self.init_bits - self.min_bits) / (self.max_bits - self.min_bits)
assert 0 < t < 1
logit = torch.logit(torch.tensor(float(t)))
assert abs(self._get_bits(logit) - self.init_bits) < 1e-5
if self.group_size > 0:
nparam = param.numel() // self.group_size
else:
nparam = 1
logit = torch.nn.Parameter(
torch.full(
(nparam,),
logit,
device=param.device))
module.register_parameter(name + self.suffix, logit)
return self.__class__._QuantizedParam(
name=name, param=param, module=module, logit=logit, other=None)