def __init__()

in tensorflow_compression/python/entropy_models/continuous_indexed.py [0:0]


  def __init__(self,
               prior_fn,
               index_ranges,
               parameter_fns,
               coding_rank,
               channel_axis=-1,
               compression=False,
               stateless=False,
               expected_grads=False,
               tail_mass=2**-8,
               range_coder_precision=12,
               dtype=tf.float32,
               laplace_tail_mass=0):
    """Initializes the instance.

    Args:
      prior_fn: A callable returning a `tfp.distributions.Distribution` object,
        typically a `Distribution` class or factory function. This is a density
        model fitting the marginal distribution of the bottleneck data with
        additive uniform noise, which is shared a priori between the sender and
        the receiver. For best results, the distributions should be flexible
        enough to have a unit-width uniform distribution as a special case,
        since this is the marginal distribution for bottleneck dimensions that
        are constant. The callable will receive keyword arguments as determined
        by `parameter_fns`.
      index_ranges: Iterable of integers. `indexes` must have the same shape as
        the bottleneck tensor, with an additional dimension at position
        `channel_axis`. The values of the `k`th channel must be in the range
        `[0, index_ranges[k])`. The iterable is consumed exactly once, so a
        one-shot iterator (e.g. a generator) is accepted.
      parameter_fns: Dict of strings to callables. Functions mapping `indexes`
        to each distribution parameter. For each item, `indexes` is passed to
        the callable, and the string key and return value make up one keyword
        argument to `prior_fn`.
      coding_rank: Integer. Number of innermost dimensions considered a coding
        unit. Each coding unit is compressed to its own bit string, and the
        bits in the `__call__` method are summed over each coding unit.
      channel_axis: Integer or `None`. Determines the position of the channel
        axis in `indexes`. Defaults to the last dimension. If set to `None`,
        the index tensor is expected to have the same shape as the bottleneck
        tensor (only allowed when `index_ranges` has length 1).
      compression: Boolean. If set to `True`, the range coding tables used by
        `compress()` and `decompress()` will be built on instantiation. If set
        to `False`, these two methods will not be accessible.
      stateless: Boolean. If `False`, range coding tables are created as
        `Variable`s. This allows the entropy model to be serialized using the
        `SavedModel` protocol, so that both the encoder and the decoder use
        identical tables when loading the stored model. If `True`, creates range
        coding tables as `Tensor`s. This makes the entropy model stateless and
        allows it to be constructed within a `tf.function` body, for when the
        range coding tables are provided manually. If `compression=False`, then
        `stateless=True` is implied and the provided value is ignored.
      expected_grads: Boolean. If True, will use analytical expected gradients
        during backpropagation w.r.t. additive uniform noise.
      tail_mass: Float. Approximate probability mass which is encoded using an
        Elias gamma code embedded into the range coder.
      range_coder_precision: Integer. Precision passed to the range coding op.
        Only used when `compression=True`, to build the coding tables.
      dtype: `tf.dtypes.DType`. Data type of this entropy model (i.e. dtype of
        prior, decompressed values).
      laplace_tail_mass: Float. If positive, will augment the prior with a
        laplace mixture for training stability. (experimental)

    Raises:
      TypeError: If `prior_fn` is not callable, or if `parameter_fns` has a
        non-string key or a non-callable value.
      ValueError: If `index_ranges` is empty, or if `channel_axis` is `None`
        while `index_ranges` has more than one element.
    """
    if not callable(prior_fn):
      raise TypeError("`prior_fn` must be a class or factory function.")
    for name, fn in parameter_fns.items():
      if not isinstance(name, str):
        raise TypeError("`parameter_fns` must have string keys.")
      if not callable(fn):
        raise TypeError(f"`parameter_fns['{name}']` must be callable.")

    super().__init__(
        coding_rank=coding_rank,
        compression=compression,
        stateless=stateless,
        expected_grads=expected_grads,
        tail_mass=tail_mass,
        dtype=dtype,
        laplace_tail_mass=laplace_tail_mass,
    )
    # Materialize `index_ranges` exactly once. The argument may be any
    # iterable, including a one-shot generator, so all later code must read
    # the stored tuple rather than the raw argument.
    self._index_ranges = tuple(int(r) for r in index_ranges)
    if not self.index_ranges:
      raise ValueError("`index_ranges` must have at least one element.")
    self._channel_axis = None if channel_axis is None else int(channel_axis)
    if self.channel_axis is None and len(self.index_ranges) > 1:
      raise ValueError(
          "`channel_axis` can't be `None` for `len(index_ranges) > 1`.")
    self._prior_fn = prior_fn
    self._parameter_fns = dict(parameter_fns)

    with self.name_scope:
      if self.compression:
        # Enumerate every index value (or, with several channels, every
        # combination of per-channel values via an "ij" meshgrid stacked along
        # the channel axis), so the prior built from them -- and the coding
        # tables derived from that prior -- span all valid indexes.
        if self.channel_axis is None:
          # Length 1 is guaranteed by the check above.
          index_range, = self.index_ranges
          indexes = tf.range(index_range, dtype=self.dtype)
        else:
          indexes = [tf.range(r, dtype=self.dtype) for r in self.index_ranges]
          indexes = tf.meshgrid(*indexes, indexing="ij")
          indexes = tf.stack(indexes, axis=self.channel_axis)
        self._prior = self._make_prior(indexes)
        cdf, cdf_offset = self._build_tables(self.prior, range_coder_precision)
        self._init_compression(cdf, cdf_offset, None)