def __init__()

in tensorflow_compression/python/entropy_models/continuous_batched.py [0:0]


  def __init__(self,
               prior=None,
               coding_rank=None,
               compression=False,
               stateless=False,
               expected_grads=False,
               tail_mass=2**-8,
               range_coder_precision=12,
               dtype=None,
               prior_shape=None,
               cdf=None,
               cdf_offset=None,
               cdf_shapes=None,
               offset_heuristic=True,
               quantization_offset=None,
               laplace_tail_mass=0):
    """Initializes the instance.

    Exactly one of `prior`, `cdf`, or `cdf_shapes` must be given. It selects
    where the range coding tables come from:
    - built from `prior`;
    - taken from `cdf`/`cdf_offset`; or
    - created empty with the shapes in `cdf_shapes`, to be restored later
      using `set_weights`.

    Args:
      prior: A `tfp.distributions.Distribution` object. A density model fitting
        the marginal distribution of the bottleneck data with additive uniform
        noise, which is shared a priori between the sender and the receiver. For
        best results, the distribution should be flexible enough to have a
        unit-width uniform distribution as a special case, since this is the
        marginal distribution for bottleneck dimensions that are constant. The
        distribution parameters may not depend on data (they must be either
        variables or constants).
      coding_rank: Integer. Number of innermost dimensions considered a coding
        unit. Each coding unit is compressed to its own bit string, and the
        bits in the __call__ method are summed over each coding unit.
      compression: Boolean. If set to `True`, the range coding tables used by
        `compress()` and `decompress()` will be built on instantiation. If set
        to `False`, these two methods will not be accessible.
      stateless: Boolean. If `False`, range coding tables are created as
        `Variable`s. This allows the entropy model to be serialized using the
        `SavedModel` protocol, so that both the encoder and the decoder use
        identical tables when loading the stored model. If `True`, creates range
        coding tables as `Tensor`s. This makes the entropy model stateless and
        allows it to be constructed within a `tf.function` body, for when the
        range coding tables are provided manually. If `compression=False`, then
        `stateless=True` is implied and the provided value is ignored.
      expected_grads: If True, will use analytical expected gradients during
        backpropagation w.r.t. additive uniform noise.
      tail_mass: Float. Approximate probability mass which is encoded using an
        Elias gamma code embedded into the range coder.
      range_coder_precision: Integer. Precision passed to the range coding op.
      dtype: `tf.dtypes.DType`. Data type of this entropy model (i.e. dtype of
        prior, decompressed values). Must be provided if `prior` is omitted.
      prior_shape: Batch shape of the prior (dimensions which are not assumed
        i.i.d.). Must be provided if `prior` is omitted.
      cdf: `tf.Tensor` or `None`. If provided, is used for range coding rather
        than tables built from the prior.
      cdf_offset: `tf.Tensor` or `None`. Must be provided if and only if `cdf`
        is provided.
      cdf_shapes: Shapes of `cdf` and `cdf_offset`. If provided, empty range
        coding tables are created, which can then be restored using
        `set_weights`. Requires `compression=True` and `stateless=False`.
      offset_heuristic: Boolean. Whether to quantize to non-integer offsets
        heuristically determined from mode/median of prior. Set this to `False`
        if you are using soft quantization during training.
      quantization_offset: `tf.Tensor` or `None`. The quantization offsets to
        use. If provided (not `None`), then `offset_heuristic` is ineffective.
        When `cdf_shapes` is provided, this must instead be a Python `bool`:
        `True` creates a zero-valued placeholder offset of shape
        `prior_shape`, and `False` means no offset is used.
      laplace_tail_mass: Float. If positive, will augment the prior with a
        Laplace mixture for training stability. (experimental)

    Raises:
      ValueError: In any of the following cases:
        - neither `prior` nor both `dtype` and `prior_shape` are given, or
          both alternatives are given;
        - not exactly one of `prior`, `cdf`, or `cdf_shapes` is given;
        - CDFs are given with `compression=False`;
        - `prior` is not a (batch of) scalar distribution(s);
        - `cdf` and `cdf_offset` are not given together;
        - `coding_rank` is smaller than the rank of `prior_shape`;
        - `quantization_offset` is not a `bool` while `cdf_shapes` is given;
        - the offset heuristic is needed for compression but no `prior` is
          available.
    """
    # `prior` and the (`dtype`, `prior_shape`) pair are alternatives: give
    # either `prior` alone, or both `dtype` and `prior_shape` without `prior`.
    has_prior = prior is not None
    if (has_prior == (dtype is not None) or
        has_prior == (prior_shape is not None)):
      raise ValueError(
          "Either `prior` or both `dtype` and `prior_shape` must be provided.")
    # The range coding tables must come from exactly one source.
    if sum(source is not None for source in (prior, cdf, cdf_shapes)) != 1:
      raise ValueError(
          "Must provide exactly one of `prior`, `cdf`, or `cdf_shapes`.")
    if not compression and not (
        cdf is None and cdf_offset is None and cdf_shapes is None):
      raise ValueError("CDFs can't be provided with `compression=False`")
    if prior is not None and prior.event_shape.rank:
      raise ValueError("`prior` must be a (batch of) scalar distribution(s).")
    # Without this check, a `cdf_offset` passed alongside `prior` would be
    # silently replaced by the tables built from the prior below, and a `cdf`
    # without `cdf_offset` would be forwarded with a `None` offset.
    if (cdf is None) != (cdf_offset is None):
      raise ValueError("`cdf` and `cdf_offset` must be provided together.")

    super().__init__(
        coding_rank=coding_rank,
        compression=compression,
        stateless=stateless,
        expected_grads=expected_grads,
        tail_mass=tail_mass,
        dtype=dtype if dtype is not None else prior.dtype,
        laplace_tail_mass=laplace_tail_mass,
    )
    self._prior = prior
    self._offset_heuristic = bool(offset_heuristic)
    self._prior_shape = tf.TensorShape(
        prior_shape if prior is None else prior.batch_shape)
    if self.coding_rank < self.prior_shape.rank:
      raise ValueError("`coding_rank` can't be smaller than `prior_shape`.")

    with self.name_scope:
      if cdf_shapes is not None:
        # `cdf_shapes` being set indicates that we are using the `SavedModel`
        # protocol, which can only provide JSON datatypes. So
        # `quantization_offset` arrives as a Boolean recording whether an
        # offset exists, and we create a placeholder value accordingly. (In
        # all other cases we expect either `None` or a tensor.) An explicit
        # check is used instead of `assert`, which is stripped under `-O`.
        # `compression=True` is already guaranteed by the CDF check above.
        if not isinstance(quantization_offset, bool):
          raise ValueError(
              "`quantization_offset` must be a bool when `cdf_shapes` is "
              f"provided, got {type(quantization_offset).__name__}.")
        if quantization_offset:
          quantization_offset = tf.zeros(
              self.prior_shape_tensor, dtype=self.dtype)
        else:
          quantization_offset = None
      elif quantization_offset is not None:
        # If quantization offset is passed in manually, use it.
        pass
      elif self.offset_heuristic and self.compression:
        # For compression, we need to fix the offset value, so compute it here.
        if self._prior is None:
          raise ValueError(
              "To use the offset heuristic, a `prior` needs to be provided.")
        quantization_offset = helpers.quantization_offset(self.prior)
        # Optimization: if the quantization offset is zero, we don't need to
        # subtract/add it when quantizing, and we don't need to serialize its
        # value. The zero test needs a concrete value, so it is only applied in
        # eager mode; in graph mode the offset is always kept.
        if (tf.executing_eagerly() and
            tf.reduce_all(tf.equal(quantization_offset, 0.))):
          quantization_offset = None
        else:
          quantization_offset = tf.broadcast_to(
              quantization_offset, self.prior_shape_tensor)
      else:
        quantization_offset = None

      # Stateful compression models keep the offset in a non-trainable
      # `Variable` so it is serialized; otherwise a plain tensor suffices.
      if quantization_offset is None:
        self._quantization_offset = None
      elif self.compression and not self.stateless:
        self._quantization_offset = tf.Variable(
            quantization_offset, dtype=self.dtype, trainable=False,
            name="quantization_offset")
      else:
        self._quantization_offset = tf.convert_to_tensor(
            quantization_offset, dtype=self.dtype, name="quantization_offset")

      if self.compression:
        # Only the `prior` path needs tables built here; `cdf` and
        # `cdf_shapes` are handed to `_init_compression` directly.
        if cdf is None and cdf_shapes is None:
          cdf, cdf_offset = self._build_tables(
              self.prior, range_coder_precision, offset=quantization_offset)
        self._init_compression(cdf, cdf_offset, cdf_shapes)