# neuralcompression/layers/_continuous_entropy.py — ContinuousEntropy._update
def _update(self):
    """Rebuild the quantized CDF tables used by the range coder.

    Samples the prior's probability mass on an integer grid shifted by each
    channel's quantization offset, then converts each per-channel PMF into a
    quantized CDF (with a trailing overflow bin) at the coder's precision.

    Returns:
        cdfs: ``int32`` tensor with one quantized CDF per row, zero-padded
            to the widest table (row width ``max(pmf_sizes) + 2``).
        cdf_sizes: number of valid entries in each CDF row.
        cdf_offsets: integer grid offset (``minimum``) of each table.
    """
    quantization_offset = ncF.quantization_offset(self.prior)

    # UniformNoise priors expose their own closed-form tail bounds; other
    # priors go through the generic functional helpers.
    if isinstance(self.prior, UniformNoise):
        lower_tail = self.prior.lower_tail(self.tail_mass)
        upper_tail = self.prior.upper_tail(self.tail_mass)
    else:
        lower_tail = ncF.lower_tail(self.prior, self.tail_mass)
        upper_tail = ncF.upper_tail(self.prior, self.tail_mass)

    # Integer support bounds of the sampling grid, clamped non-negative.
    minimum = torch.floor(lower_tail - quantization_offset)
    minimum = minimum.to(torch.int32)
    minimum = torch.clamp_min(minimum, 0)

    maximum = torch.ceil(upper_tail - quantization_offset)
    maximum = maximum.to(torch.int32)
    maximum = torch.clamp_min(maximum, 0)

    # BUG FIX: pmf_start must stay in the prior's floating dtype. The
    # previous `pmf_start.to(torch.int32)` truncated away the fractional
    # quantization offset that was just added, so the PMF was sampled on
    # the wrong (unshifted) grid for every channel with a nonzero offset.
    pmf_start = minimum.to(self.prior_dtype) + quantization_offset

    pmf_sizes = maximum - minimum + 1
    # Plain Python int: feeds both torch.arange() and the table width below
    # (the original cast it prior_dtype -> int32 back to back, then still
    # needed int(...) when sizing the table).
    maximum_pmf_size = int(torch.max(pmf_sizes))

    samples = torch.arange(maximum_pmf_size).to(self.prior_dtype)
    # Trailing singleton dims let the sample axis broadcast against the
    # per-channel pmf_start (whose shape is context_shape).
    samples = samples.reshape([-1] + len(self.context_shape) * [1])
    samples = samples + pmf_start

    if isinstance(self.prior, UniformNoise):
        pmfs = self.prior.prob(samples)
    else:
        pmfs = torch.exp(self.prior.log_prob(samples))

    pmf_sizes = torch.broadcast_to(pmf_sizes, self.context_shape)
    pmf_sizes = pmf_sizes.squeeze()

    # +2: one slot for the overflow bin appended below, one for the extra
    # entry a quantized CDF carries beyond its PMF.
    cdf_sizes = pmf_sizes + 2

    cdf_offsets = torch.broadcast_to(minimum, self.context_shape)
    cdf_offsets = cdf_offsets.squeeze()

    cdfs = torch.zeros(
        (len(pmf_sizes), maximum_pmf_size + 2),
        dtype=torch.int32,
        device=pmfs.device,
    )

    # NOTE(review): this walks the leading axis of `pmfs` in lockstep with
    # `pmf_sizes`, which is only correct if that axis indexes channels —
    # confirm against how context_shape/samples are laid out upstream.
    for index, (pmf, pmf_size) in enumerate(zip(pmfs, pmf_sizes)):
        pmf = pmf[:pmf_size]

        # Fold any probability mass outside the tabulated range into a
        # single overflow bin so the quantized CDF still sums to one.
        overflow = torch.clamp(
            1 - torch.sum(pmf, dim=0, keepdim=True),
            min=0.0,
        )

        pmf = torch.cat([pmf, overflow], dim=0)

        cdf = ncF.pmf_to_quantized_cdf(
            pmf,
            self._range_coder_precision,
        )

        cdfs[index, : cdf.size()[0]] = cdf

    return cdfs, cdf_sizes, cdf_offsets