in tensorflow_compression/python/entropy_models/universal.py [0:0]
def __init__(self,
             prior_fn,
             index_ranges,
             parameter_fns,
             coding_rank,
             compression=False,
             dtype=tf.float32,
             laplace_tail_mass=0.0,
             expected_grads=False,
             tail_mass=2**-8,
             range_coder_precision=12,
             stateless=False,
             num_noise_levels=15):
  """Initializes the instance.

  Args:
    prior_fn: A callable returning a `tfp.distributions.Distribution` object,
      typically a `Distribution` class or factory function. This is a density
      model fitting the marginal distribution of the bottleneck data with
      additive uniform noise, which is shared a priori between the sender and
      the receiver. For best results, the distributions should be flexible
      enough to have a unit-width uniform distribution as a special case,
      since this is the marginal distribution for bottleneck dimensions that
      are constant. The callable will receive keyword arguments as determined
      by `parameter_fns`.
    index_ranges: Iterable of integers. Compared to `bottleneck`, `indexes`
      in `__call__()` must have an additional trailing dimension, and the
      values of the `k`th channel must be in the range `[0, index_ranges[k])`.
    parameter_fns: Dict of strings to callables. Functions mapping `indexes`
      to each distribution parameter. For each item, `indexes` is passed to
      the callable, and the string key and return value make up one keyword
      argument to `prior_fn`.
    coding_rank: Integer. Number of innermost dimensions considered a coding
      unit. Each coding unit is compressed to its own bit string, and the
      `bits()` method sums over each coding unit.
    compression: Boolean. If set to `True`, the range coding tables used by
      `compress()` and `decompress()` will be built on instantiation. This
      assumes eager mode (throws an error if in graph mode or inside a
      `tf.function` call). If set to `False`, these two methods will not be
      accessible.
    dtype: `tf.dtypes.DType`. The data type of all floating-point computations
      carried out in this class.
    laplace_tail_mass: Float. If positive, will augment the prior with a
      laplace mixture for training stability. (experimental)
    expected_grads: If True, will use analytical expected gradients during
      backpropagation w.r.t. additive uniform noise.
    tail_mass: Float. Approximate probability mass which is encoded using an
      Elias gamma code embedded into the range coder.
    range_coder_precision: Integer. Precision passed to the range coding op.
    stateless: Boolean. If True, creates range coding tables as `Tensor`s
      rather than `Variable`s.
    num_noise_levels: Integer. The number of levels used to quantize the
      uniform noise.

  Raises:
    ValueError: If `coding_rank` is not positive, or `index_ranges` is empty.
    TypeError: If `prior_fn` or any value of `parameter_fns` is not callable,
      or any key of `parameter_fns` is not a string.
  """
  if coding_rank <= 0:
    raise ValueError("`coding_rank` must be larger than 0.")
  if not callable(prior_fn):
    raise TypeError("`prior_fn` must be a class or factory function.")
  for name, fn in parameter_fns.items():
    if not isinstance(name, str):
      raise TypeError("`parameter_fns` must have string keys.")
    if not callable(fn):
      raise TypeError(f"`parameter_fns['{name}']` must be callable.")
  # Validate the user-supplied ranges *before* prepending the noise-level
  # range below. Once `num_noise_levels` is prepended, the tuple can never
  # be empty, so a check on the combined tuple would be dead code and an
  # empty `index_ranges` argument would slip through silently.
  index_ranges = tuple(int(r) for r in index_ranges)
  if not index_ranges:
    raise ValueError("`index_ranges` must have at least one element.")
  super().__init__(
      coding_rank=coding_rank,
      compression=compression,
      stateless=stateless,
      expected_grads=expected_grads,
      tail_mass=tail_mass,
      dtype=dtype,
      laplace_tail_mass=laplace_tail_mass,
  )
  # Add an extra leading index dimension enumerating the noise levels.
  self._index_ranges = (num_noise_levels,) + index_ranges
  self._prior_fn = prior_fn
  self._parameter_fns = dict(parameter_fns)
  self._num_noise_levels = num_noise_levels
  with self.name_scope:
    if self.compression:
      # Enumerate the full Cartesian product of index values (one grid
      # point per distribution in the prior's batch) and build the range
      # coding tables from it.
      index_ranges = self.index_ranges_without_offsets
      indexes = [tf.range(r, dtype=self.dtype) for r in index_ranges]
      indexes = tf.meshgrid(*indexes, indexing="ij")
      indexes = tf.stack(indexes, axis=-1)
      self._prior = self._make_prior(indexes)
      offset = _range_coding_offsets(
          self._num_noise_levels, self.prior.batch_shape, self.dtype)
      cdf, cdf_offset = self._build_tables(
          self.prior, range_coder_precision, offset=offset)
      self._init_compression(cdf, cdf_offset, None)