def _update()

in neuralcompression/layers/_continuous_entropy.py [0:0]


    def _update(self):
        """Build the quantized CDF tables used by the range coder.

        The prior is sampled at unit-spaced points, shifted by the prior's
        quantization offset, between its lower and upper tails. There is one
        PMF per element of ``self.context_shape``. Each PMF is extended with
        an overflow bin holding the remaining probability mass. It is then
        converted to a quantized CDF with
        ``ncF.pmf_to_quantized_cdf(..., self._range_coder_precision)``.

        Returns:
            A tuple ``(cdfs, cdf_sizes, cdf_offsets)``:

            - ``cdfs``: ``int32`` tensor of shape
              ``(num_pmfs, maximum_pmf_size + 2)``. Row ``i`` holds the
              quantized CDF of the ``i``-th distribution (context dimensions
              flattened in row-major order), zero-padded on the right.
            - ``cdf_sizes``: per-distribution PMF length + 2, shape
              ``(num_pmfs,)``.
            - ``cdf_offsets``: integer value of the first PMF bin of each
              distribution, shape ``(num_pmfs,)``.
        """
        quantization_offset = ncF.quantization_offset(self.prior)

        # UniformNoise exposes its own tail methods; other priors use the
        # generic ncF helpers.
        if isinstance(self.prior, UniformNoise):
            lower_tail = self.prior.lower_tail(self.tail_mass)
            upper_tail = self.prior.upper_tail(self.tail_mass)
        else:
            lower_tail = ncF.lower_tail(self.prior, self.tail_mass)
            upper_tail = ncF.upper_tail(self.prior, self.tail_mass)

        # Integer bounds with
        #   minimum + offset <= lower_tail and maximum + offset >= upper_tail.
        # NOTE(review): clamping ``minimum`` at 0 removes every symbol below
        # 0 from the table; that mass ends up in the overflow bin below.
        # Kept for compatibility with the encode/decode path - confirm this
        # is intended.
        minimum = torch.floor(lower_tail - quantization_offset)
        minimum = torch.clamp_min(minimum.to(torch.int32), 0)

        maximum = torch.ceil(upper_tail - quantization_offset)
        maximum = torch.clamp_min(maximum.to(torch.int32), 0)

        # First sample location of each PMF. This must stay floating point:
        # casting to an integer would discard the fractional quantization
        # offset and evaluate the prior at the wrong points.
        pmf_start = minimum.to(self.prior_dtype) + quantization_offset

        pmf_sizes = maximum - minimum + 1
        maximum_pmf_size = int(torch.max(pmf_sizes))

        # Sample every distribution at ``maximum_pmf_size`` points (possibly
        # more than a given distribution needs; the surplus is sliced off in
        # the loop). Resulting shape: (maximum_pmf_size, *context_shape).
        samples = torch.arange(
            maximum_pmf_size,
            dtype=self.prior_dtype,
            device=pmf_start.device,
        )
        samples = samples.reshape([-1] + len(self.context_shape) * [1])
        samples = samples + pmf_start

        if isinstance(self.prior, UniformNoise):
            pmfs = self.prior.prob(samples)
        else:
            pmfs = torch.exp(self.prior.log_prob(samples))

        # Put one distribution per row so the loop below pairs each PMF with
        # its own size:
        # (maximum_pmf_size, *context_shape) -> (num_pmfs, maximum_pmf_size).
        pmfs = torch.broadcast_to(
            pmfs, (maximum_pmf_size, *self.context_shape)
        )
        pmfs = pmfs.reshape(maximum_pmf_size, -1).transpose(0, 1)

        # ``reshape(-1)`` rather than ``squeeze()``: a single-element context
        # must still give a 1-D tensor (``squeeze`` yields a 0-d tensor that
        # ``len()`` rejects), and multi-dimensional contexts must match the
        # flattened row order of ``pmfs``.
        pmf_sizes = torch.broadcast_to(pmf_sizes, self.context_shape)
        pmf_sizes = pmf_sizes.reshape(-1)

        cdf_sizes = pmf_sizes + 2

        cdf_offsets = torch.broadcast_to(minimum, self.context_shape)
        cdf_offsets = cdf_offsets.reshape(-1)

        # Row width: PMF bins + one overflow bin + one extra CDF entry
        # (assumes pmf_to_quantized_cdf returns len(pmf) + 1 values).
        cdfs = torch.zeros(
            (len(pmf_sizes), maximum_pmf_size + 2),
            dtype=torch.int32,
            device=pmfs.device,
        )

        for index, (pmf, pmf_size) in enumerate(zip(pmfs, pmf_sizes)):
            pmf = pmf[:pmf_size]

            # Mass outside the sampled range goes into an overflow bin.
            overflow = torch.clamp(
                1 - torch.sum(pmf, dim=0, keepdim=True),
                min=0.0,
            )

            pmf = torch.cat([pmf, overflow], dim=0)

            cdf = ncF.pmf_to_quantized_cdf(
                pmf,
                self._range_coder_precision,
            )

            cdfs[index, : cdf.size()[0]] = cdf

        return cdfs, cdf_sizes, cdf_offsets