def __init__()

in tensorflow_compression/python/entropy_models/universal.py [0:0]


  def __init__(self,
               prior_fn,
               index_ranges,
               parameter_fns,
               coding_rank,
               compression=False,
               dtype=tf.float32,
               laplace_tail_mass=0.0,
               expected_grads=False,
               tail_mass=2**-8,
               range_coder_precision=12,
               stateless=False,
               num_noise_levels=15):
    """Initializes the instance.

    Args:
      prior_fn: A callable returning a `tfp.distributions.Distribution` object,
        typically a `Distribution` class or factory function. This is a density
        model fitting the marginal distribution of the bottleneck data with
        additive uniform noise, which is shared a priori between the sender and
        the receiver. For best results, the distributions should be flexible
        enough to have a unit-width uniform distribution as a special case,
        since this is the marginal distribution for bottleneck dimensions that
        are constant. The callable will receive keyword arguments as determined
        by `parameter_fns`.
      index_ranges: Iterable of integers. Compared to `bottleneck`, `indexes`
        in `__call__()` must have an additional trailing dimension, and the
        values of the `k`th channel must be in the range `[0, index_ranges[k])`.
        Must contain at least one element. Internally, an extra leading range
        of size `num_noise_levels` is prepended for the noise-level index.
      parameter_fns: Dict of strings to callables. Functions mapping `indexes`
        to each distribution parameter. For each item, `indexes` is passed to
        the callable, and the string key and return value make up one keyword
        argument to `prior_fn`.
      coding_rank: Integer. Number of innermost dimensions considered a coding
        unit. Each coding unit is compressed to its own bit string, and the
        `bits()` method sums over each coding unit.
      compression: Boolean. If set to `True`, the range coding tables used by
        `compress()` and `decompress()` will be built on instantiation. This
        assumes eager mode (throws an error if in graph mode or inside a
        `tf.function` call). If set to `False`, these two methods will not be
        accessible.
      dtype: `tf.dtypes.DType`. The data type of all floating-point computations
        carried out in this class.
      laplace_tail_mass: Float. If positive, will augment the prior with a
        Laplace mixture for training stability. (experimental)
      expected_grads: Boolean. If True, will use analytical expected gradients
        during backpropagation w.r.t. additive uniform noise.
      tail_mass: Float. Approximate probability mass which is encoded using an
        Elias gamma code embedded into the range coder.
      range_coder_precision: Integer. Precision passed to the range coding op.
        Only used when `compression` is `True`.
      stateless: Boolean. If True, creates range coding tables as `Tensor`s
        rather than `Variable`s.
      num_noise_levels: Integer. The number of levels used to quantize the
        uniform noise. Must be at least 1.

    Raises:
      ValueError: If `coding_rank` is not positive, `index_ranges` is empty,
        or `num_noise_levels` is less than 1.
      TypeError: If `prior_fn` is not callable, or `parameter_fns` has a
        non-string key or a non-callable value.
    """
    if coding_rank <= 0:
      raise ValueError("`coding_rank` must be larger than 0.")
    if not callable(prior_fn):
      raise TypeError("`prior_fn` must be a class or factory function.")
    for name, fn in parameter_fns.items():
      if not isinstance(name, str):
        raise TypeError("`parameter_fns` must have string keys.")
      if not callable(fn):
        raise TypeError(f"`parameter_fns['{name}']` must be callable.")
    # Materialize the caller's ranges exactly once (the argument may be a
    # one-shot iterable) and validate them *before* the noise-level range is
    # prepended: after prepending, the tuple can never be empty, so checking
    # it then would not catch an empty `index_ranges`.
    user_index_ranges = tuple(int(r) for r in index_ranges)
    if not user_index_ranges:
      raise ValueError("`index_ranges` must have at least one element.")
    if num_noise_levels < 1:
      raise ValueError("`num_noise_levels` must be at least 1.")

    super().__init__(
        coding_rank=coding_rank,
        compression=compression,
        stateless=stateless,
        expected_grads=expected_grads,
        tail_mass=tail_mass,
        dtype=dtype,
        laplace_tail_mass=laplace_tail_mass,
    )
    # Add an extra leading index dimension for the quantized noise level.
    self._index_ranges = (num_noise_levels,) + user_index_ranges
    self._prior_fn = prior_fn
    self._parameter_fns = dict(parameter_fns)
    self._num_noise_levels = num_noise_levels

    with self.name_scope:
      if self.compression:
        # Enumerate every index combination on a dense grid, so that the prior
        # (and hence one range coding table per combination) is built for all
        # of them. The last axis of `grid` holds one index per dimension.
        ranges = self.index_ranges_without_offsets
        axes = [tf.range(r, dtype=self.dtype) for r in ranges]
        grid = tf.stack(tf.meshgrid(*axes, indexing="ij"), axis=-1)
        self._prior = self._make_prior(grid)
        # NOTE(review): `_range_coding_offsets` is defined elsewhere in this
        # module; presumably it yields per-noise-level offsets used when
        # tabulating the prior — confirm against its definition.
        offset = _range_coding_offsets(
            self._num_noise_levels, self.prior.batch_shape, self.dtype)
        cdf, cdf_offset = self._build_tables(
            self.prior, range_coder_precision, offset=offset)
        self._init_compression(cdf, cdf_offset, None)