in tensorflow_compression/python/entropy_models/continuous_indexed.py [0:0]
def __init__(self,
             prior_fn,
             index_ranges,
             parameter_fns,
             coding_rank,
             channel_axis=-1,
             compression=False,
             stateless=False,
             expected_grads=False,
             tail_mass=2**-8,
             range_coder_precision=12,
             dtype=tf.float32,
             laplace_tail_mass=0):
  """Initializes the instance.

  Args:
    prior_fn: A callable returning a `tfp.distributions.Distribution` object,
      typically a `Distribution` class or factory function. This is a density
      model fitting the marginal distribution of the bottleneck data with
      additive uniform noise, which is shared a priori between the sender and
      the receiver. For best results, the distributions should be flexible
      enough to have a unit-width uniform distribution as a special case,
      since this is the marginal distribution for bottleneck dimensions that
      are constant. The callable will receive keyword arguments as determined
      by `parameter_fns`.
    index_ranges: Iterable of integers. `indexes` must have the same shape as
      the bottleneck tensor, with an additional dimension at position
      `channel_axis`. The values of the `k`th channel must be in the range
      `[0, index_ranges[k])`. The iterable is consumed exactly once, so
      one-shot iterators (e.g. generators) are accepted. Each value must be
      positive.
    parameter_fns: Dict of strings to callables. Functions mapping `indexes`
      to each distribution parameter. For each item, `indexes` is passed to
      the callable, and the string key and return value make up one keyword
      argument to `prior_fn`.
    coding_rank: Integer. Number of innermost dimensions considered a coding
      unit. Each coding unit is compressed to its own bit string, and the
      bits in the `__call__` method are summed over each coding unit.
    channel_axis: Integer or `None`. Determines the position of the channel
      axis in `indexes`. Defaults to the last dimension. If set to `None`,
      the index tensor is expected to have the same shape as the bottleneck
      tensor (only allowed when `index_ranges` has length 1).
    compression: Boolean. If set to `True`, the range coding tables used by
      `compress()` and `decompress()` will be built on instantiation. If set
      to `False`, these two methods will not be accessible.
    stateless: Boolean. If `False`, range coding tables are created as
      `Variable`s. This allows the entropy model to be serialized using the
      `SavedModel` protocol, so that both the encoder and the decoder use
      identical tables when loading the stored model. If `True`, creates range
      coding tables as `Tensor`s. This makes the entropy model stateless and
      allows it to be constructed within a `tf.function` body, for when the
      range coding tables are provided manually. If `compression=False`, then
      `stateless=True` is implied and the provided value is ignored.
    expected_grads: Boolean. If `True`, will use analytical expected
      gradients during backpropagation w.r.t. additive uniform noise.
    tail_mass: Float. Approximate probability mass which is encoded using an
      Elias gamma code embedded into the range coder.
    range_coder_precision: Integer. Precision passed to the range coding op.
    dtype: `tf.dtypes.DType`. Data type of this entropy model (i.e. dtype of
      prior, decompressed values).
    laplace_tail_mass: Float. If positive, will augment the prior with a
      Laplace mixture for training stability. (experimental)

  Raises:
    TypeError: If `prior_fn` is not callable, or if `parameter_fns` has a
      non-string key or a non-callable value.
    ValueError: If `index_ranges` is empty or contains a non-positive value,
      or if `channel_axis` is `None` while `index_ranges` has more than one
      element.
  """
  if not callable(prior_fn):
    raise TypeError("`prior_fn` must be a class or factory function.")
  for name, fn in parameter_fns.items():
    if not isinstance(name, str):
      raise TypeError("`parameter_fns` must have string keys.")
    if not callable(fn):
      raise TypeError(f"`parameter_fns['{name}']` must be callable.")
  super().__init__(
      coding_rank=coding_rank,
      compression=compression,
      stateless=stateless,
      expected_grads=expected_grads,
      tail_mass=tail_mass,
      dtype=dtype,
      laplace_tail_mass=laplace_tail_mass,
  )
  # Materialize `index_ranges` exactly once. Everything below must read
  # `self.index_ranges`, because the argument may be a one-shot iterator that
  # is already exhausted at this point.
  self._index_ranges = tuple(int(r) for r in index_ranges)
  if not self.index_ranges:
    raise ValueError("`index_ranges` must have at least one element.")
  if any(r <= 0 for r in self.index_ranges):
    # A range of `r <= 0` leaves `[0, r)` empty: no index would be valid.
    raise ValueError("`index_ranges` must contain only positive integers.")
  self._channel_axis = None if channel_axis is None else int(channel_axis)
  if self.channel_axis is None and len(self.index_ranges) > 1:
    raise ValueError(
        "`channel_axis` can't be `None` for `len(index_ranges) > 1`.")
  self._prior_fn = prior_fn
  self._parameter_fns = dict(parameter_fns)
  with self.name_scope:
    if self.compression:
      # Enumerate every valid index value (or, with a channel axis, every
      # combination of per-channel values) so the prior can be tabulated
      # into range coding tables below.
      if self.channel_axis is None:
        index_range, = self.index_ranges
        indexes = tf.range(index_range, dtype=self.dtype)
      else:
        indexes = [tf.range(r, dtype=self.dtype) for r in self.index_ranges]
        indexes = tf.meshgrid(*indexes, indexing="ij")
        indexes = tf.stack(indexes, axis=self.channel_axis)
      self._prior = self._make_prior(indexes)
      cdf, cdf_offset = self._build_tables(self.prior, range_coder_precision)
      self._init_compression(cdf, cdf_offset, None)