def _register_param()

in diffq/diffq.py [0:0]


    def _register_param(self, name, param, module, other):
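        # `other` is a previously registered _QuantizedParam this parameter is
        # tied to (e.g. shared weights): reuse its logit rather than creating one.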
        if other is not None:
            return self.__class__._QuantizedParam(
                name=name, param=param, module=module, logit=other.logit, other=other)
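        # Grouped quantization requires the group size to divide the tensor evenly.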
        assert self.group_size == 0 or param.numel() % self.group_size == 0
        # We want the initial number of bits to be init_bits: solve for the
        # interpolation factor t in (0, 1) that produces it, then store its logit.
        if self.param == "noise":
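            # Quantizing to b bits over a unit range uses steps of 1 / (2**b - 1);
            # invert the log-linear interpolation between _max_noise and _min_noise.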
            noise_scale = 1 / (2 ** self.init_bits - 1)
            t = (math.log(noise_scale) - math.log(self._max_noise)) / (
                math.log(self._min_noise) - math.log(self._max_noise))
        else:
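            # Linear parametrization: t interpolates directly between the bit bounds.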
            t = (self.init_bits - self.min_bits) / (self.max_bits - self.min_bits)
        assert 0 < t < 1
        logit = torch.logit(torch.tensor(float(t)))
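        # Sanity check: the logit must map back to init_bits.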
        assert abs(self._get_bits(logit) - self.init_bits) < 1e-5
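        # One trainable logit per group, or a single one for the whole tensor.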
        if self.group_size > 0:
            nparam = param.numel() // self.group_size
        else:
            nparam = 1
        logit = torch.nn.Parameter(
            torch.full((nparam,), logit, device=param.device))
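        # Register under `name + suffix` so the logit is trained with the module.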
        module.register_parameter(name + self.suffix, logit)
        return self.__class__._QuantizedParam(
            name=name, param=param, module=module, logit=logit, other=None)
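
For the "noise" branch, quantizing to b bits over a unit range uses steps of
1 / (2**b - 1), and t inverts a log-linear interpolation between the noise
scales at the two bit-width extremes. A self-contained sketch of that
inversion, with hypothetical stand-ins for self._min_noise and self._max_noise
(assumed here to follow the same 1 / (2**b - 1) rule):

    import math

    min_bits, max_bits, init_bits = 2, 15, 8  # illustrative values

    # Hypothetical stand-ins for self._max_noise / self._min_noise.
    max_noise = 1 / (2 ** min_bits - 1)   # coarsest quantization
    min_noise = 1 / (2 ** max_bits - 1)   # finest quantization

    # Invert log(noise) = log(max_noise) + t * (log(min_noise) - log(max_noise)).
    noise_scale = 1 / (2 ** init_bits - 1)
    t = (math.log(noise_scale) - math.log(max_noise)) / (
        math.log(min_noise) - math.log(max_noise))
    assert 0 < t < 1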
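
For the linear branch, the stored logit is just the inverse sigmoid of t, so
the learned bit-width starts exactly at init_bits. A minimal sketch, assuming
_get_bits reduces to min_bits + sigmoid(logit) * (max_bits - min_bits) (the
real helper is defined elsewhere in diffq/diffq.py):

    import torch

    min_bits, max_bits, init_bits = 2, 15, 8  # illustrative values

    t = (init_bits - min_bits) / (max_bits - min_bits)
    logit = torch.logit(torch.tensor(float(t)))

    # Assumed inverse of the initialization; mirrors the method's sanity check
    # `assert abs(self._get_bits(logit) - self.init_bits) < 1e-5`.
    bits = min_bits + torch.sigmoid(logit) * (max_bits - min_bits)
    assert abs(bits.item() - init_bits) < 1e-5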