tensorflow_compression/python/entropy_models/continuous_batched.py
def __init__(self,
prior=None,
coding_rank=None,
compression=False,
stateless=False,
expected_grads=False,
tail_mass=2**-8,
range_coder_precision=12,
dtype=None,
prior_shape=None,
cdf=None,
cdf_offset=None,
cdf_shapes=None,
offset_heuristic=True,
quantization_offset=None,
laplace_tail_mass=0):
"""Initializes the instance.
Args:
prior: A `tfp.distributions.Distribution` object. A density model fitting
the marginal distribution of the bottleneck data with additive uniform
noise, which is shared a priori between the sender and the receiver. For
best results, the distribution should be flexible enough to have a
unit-width uniform distribution as a special case, since this is the
marginal distribution for bottleneck dimensions that are constant. The
distribution parameters may not depend on data (they must be either
variables or constants).
coding_rank: Integer. Number of innermost dimensions considered a coding
unit. Each coding unit is compressed to its own bit string, and the bit
counts returned by the `__call__` method are summed over each coding unit.
compression: Boolean. If set to `True`, the range coding tables used by
`compress()` and `decompress()` will be built on instantiation. If set
to `False`, these two methods will not be accessible.
stateless: Boolean. If `False`, range coding tables are created as
`Variable`s. This allows the entropy model to be serialized using the
`SavedModel` protocol, so that both the encoder and the decoder use
identical tables when loading the stored model. If `True`, creates range
coding tables as `Tensor`s. This makes the entropy model stateless and
allows it to be constructed within a `tf.function` body, for when the
range coding tables are provided manually. If `compression=False`, then
`stateless=True` is implied and the provided value is ignored.
expected_grads: Boolean. If `True`, uses analytical expected gradients
during backpropagation w.r.t. additive uniform noise.
tail_mass: Float. Approximate probability mass which is encoded using an
Elias gamma code embedded into the range coder.
range_coder_precision: Integer. Precision passed to the range coding op.
dtype: `tf.dtypes.DType`. Data type of this entropy model (i.e. dtype of
prior, decompressed values). Must be provided if `prior` is omitted.
prior_shape: Batch shape of the prior (dimensions which are not assumed
i.i.d.). Must be provided if `prior` is omitted.
cdf: `tf.Tensor` or `None`. If provided, is used for range coding rather
than tables built from the prior.
cdf_offset: `tf.Tensor` or `None`. Must be provided along with `cdf`.
cdf_shapes: Shapes of `cdf` and `cdf_offset`. If provided, empty range
coding tables are created, which can then be restored using
`set_weights`. Requires `compression=True` and `stateless=False`.
offset_heuristic: Boolean. Whether to quantize to non-integer offsets
heuristically determined from mode/median of prior. Set this to `False`
if you are using soft quantization during training.
quantization_offset: `tf.Tensor` or `None`. The quantization offsets to
use. If provided (not `None`), then `offset_heuristic` has no effect.
laplace_tail_mass: Float. If positive, will augment the prior with a
Laplace mixture for training stability. (experimental)
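
Example (a minimal usage sketch, not taken from this file; it assumes the
class is exposed publicly as `tfc.ContinuousBatchedEntropyModel` and that
`tfc.NoisyNormal` is available as a prior):

  import tensorflow_compression as tfc

  # Hypothetical usage sketch: scalar prior with additive uniform noise.
  prior = tfc.NoisyNormal(loc=0., scale=1.)
  # `compression=True` builds the range coding tables on instantiation, so
  # `compress()` and `decompress()` are usable immediately.
  em = tfc.ContinuousBatchedEntropyModel(
      prior, coding_rank=1, compression=True)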
"""
if not (prior is not None) == (dtype is None) == (prior_shape is None):
raise ValueError(
"Either `prior` or both `dtype` and `prior_shape` must be provided.")
if (prior is None) + (cdf_shapes is None) + (cdf is None) != 2:
raise ValueError(
"Must provide exactly one of `prior`, `cdf`, or `cdf_shapes`.")
if not compression and not (
cdf is None and cdf_offset is None and cdf_shapes is None):
raise ValueError("CDFs can't be provided with `compression=False`")
if prior is not None and prior.event_shape.rank:
raise ValueError("`prior` must be a (batch of) scalar distribution(s).")
super().__init__(
coding_rank=coding_rank,
compression=compression,
stateless=stateless,
expected_grads=expected_grads,
tail_mass=tail_mass,
dtype=dtype if dtype is not None else prior.dtype,
laplace_tail_mass=laplace_tail_mass,
)
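# Record the prior (if any), the offset heuristic flag, and the prior's
# batch shape (the dimensions that are not assumed i.i.d.).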
self._prior = prior
self._offset_heuristic = bool(offset_heuristic)
self._prior_shape = tf.TensorShape(
prior_shape if prior is None else prior.batch_shape)
if self.coding_rank < self.prior_shape.rank:
raise ValueError("`coding_rank` can't be smaller than `prior_shape`.")
with self.name_scope:
if cdf_shapes is not None:
# `cdf_shapes` being set indicates that we are using the `SavedModel`
# protocol, which can only provide JSON datatypes. So create a
# placeholder value depending on whether `quantization_offset` was
# `None` or not. For this purpose, we expect a Boolean here (whereas in all
# other cases, we expect either `None` or a tensor).
assert isinstance(quantization_offset, bool)
assert self.compression
if quantization_offset:
quantization_offset = tf.zeros(
self.prior_shape_tensor, dtype=self.dtype)
else:
quantization_offset = None
elif quantization_offset is not None:
# If quantization offset is passed in manually, use it.
pass
elif self.offset_heuristic and self.compression:
# For compression, we need to fix the offset value, so compute it here.
if self._prior is None:
raise ValueError(
"To use the offset heuristic, a `prior` needs to be provided.")
quantization_offset = helpers.quantization_offset(self.prior)
# Optimization: if the quantization offset is zero, we don't need to
# subtract/add it when quantizing, and we don't need to serialize its
# value. Note that this check requires eager execution, so the optimization
# is skipped when tracing a graph.
if (tf.executing_eagerly() and
tf.reduce_all(tf.equal(quantization_offset, 0.))):
quantization_offset = None
else:
quantization_offset = tf.broadcast_to(
quantization_offset, self.prior_shape_tensor)
else:
quantization_offset = None
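# Store the offset as a `tf.Variable` when it needs to be serialized with
# the model (`compression=True` and `stateless=False`); otherwise keep it as
# a constant tensor.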
if quantization_offset is None:
self._quantization_offset = None
elif self.compression and not self.stateless:
self._quantization_offset = tf.Variable(
quantization_offset, dtype=self.dtype, trainable=False,
name="quantization_offset")
else:
self._quantization_offset = tf.convert_to_tensor(
quantization_offset, dtype=self.dtype, name="quantization_offset")
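# Finally, build the range coding tables from the prior unless CDFs (or
# their shapes) were provided explicitly, and initialize compression state.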
if self.compression:
if cdf is None and cdf_shapes is None:
cdf, cdf_offset = self._build_tables(
self.prior, range_coder_precision, offset=quantization_offset)
self._init_compression(cdf, cdf_offset, cdf_shapes)