def encode_decode()

in models/toy_sources/ntc.py [0:0]


  def encode_decode(self, x, dither_rate, dither_dist, soft_round,
                    guess_offset=None, offset=0., seed=None):
    """Maps `x` to latents, perturbs them, reconstructs `x` and computes rates.

    The latents `y = self.analysis(x)` are perturbed for two purposes: the
    "distortion" latents are fed to `self.synthesis`, and the "rate" latents
    are scored under `self.prior(soft_round=soft_round)`. If both purposes use
    the same kind of perturbation, a single perturbed tensor is shared.

    Args:
      x: Input tensor. Its last dimension must equal `self.ndim_source`.
      dither_rate: If truthy, the latents used for the rate get additive
        uniform noise in [-.5, .5); otherwise they are rounded with
        `st_round` (helper defined elsewhere in this file; presumably a
        straight-through rounding op).
      dither_dist: Same as `dither_rate`, but for the latents passed to
        `self.synthesis`.
      soft_round: Forwarded to `self.prior`. On the dithering path, the
        latents are additionally passed through `tfc.soft_round` before and
        `tfc.soft_round_conditional_mean` after adding the noise, both with
        `alpha=self.alpha`.
      guess_offset: If truthy, the rounding offset is shifted by
        `tfc.quantization_offset(prior)` (rounding path only). Defaults to
        `self.guess_offset` when None. Must not be combined with `soft_round`.
      offset: Offset subtracted before and added back after rounding. Not
        used on the dithering path.
      seed: Op-level seed forwarded to `tf.random.uniform`.

    Returns:
      A tuple `(y_dist, x_hat, rates)`: the perturbed latents fed to the
      synthesis transform, the synthesis output, and the rate in bits, i.e.
      `-sum(log2(prior.prob(y_rate)))` reduced over the last axis.
    """
    if guess_offset is None:
      guess_offset = self.guess_offset
    # It doesn't make sense to use both guess_offset and soft_round.
    assert not (guess_offset and soft_round), (
        "guess_offset and soft_round are mutually exclusive.")

    def perturb(inputs, dither, prior, offset):
      """Dithers `inputs` (training proxy) or rounds them (quantization)."""
      if dither:
        if soft_round:
          inputs = tfc.soft_round(inputs, alpha=self.alpha)
        inputs += tf.random.uniform(
            tf.shape(inputs), -.5, .5, dtype=self.dtype, seed=seed)
        if soft_round:
          inputs = tfc.soft_round_conditional_mean(inputs, alpha=self.alpha)
        return inputs
      else:
        if guess_offset:
          # Rebinds the local `offset`; the caller's value is not modified.
          offset += tfc.quantization_offset(prior)
        return st_round(inputs - offset) + offset

    assert x.shape[-1] == self.ndim_source
    y = self.analysis(x)
    prior = self.prior(soft_round=soft_round)

    y_dist = perturb(y, dither_dist, prior, offset)
    # When both paths use the same perturbation, reuse y_dist so that both
    # share a single noise sample instead of drawing a second one.
    if dither_rate == dither_dist:
      y_rate = y_dist
    else:
      y_rate = perturb(y, dither_rate, prior, offset)

    x_hat = self.synthesis(y_dist)
    log_probs = prior.log_prob(y_rate)
    # Convert nats to bits. log(2) is evaluated directly in self.dtype rather
    # than in float32 and cast afterwards, so float64 models keep full
    # precision in the conversion factor.
    ln2 = tf.math.log(tf.constant(2., dtype=self.dtype))
    rates = tf.reduce_sum(log_probs, axis=-1) / -ln2

    return y_dist, x_hat, rates