in models/toy_sources/ntc.py [0:0]
def encode_decode(self, x, dither_rate, dither_dist, soft_round,
                  guess_offset=None, offset=0., seed=None):
  """Encodes `x` to latents, perturbs or quantizes them, and decodes.

  The latents `y = self.analysis(x)` are perturbed separately for the
  distortion path (controlled by `dither_dist`) and the rate path
  (controlled by `dither_rate`). When both flags compare equal, the single
  perturbed tensor is reused for both paths.

  Args:
    x: Source samples. The size of the last axis must equal
      `self.ndim_source`.
    dither_rate: If true, the latents used for the rate term receive additive
      uniform noise in [-0.5, 0.5). Otherwise they are rounded with
      `st_round` (defined elsewhere; presumably straight-through rounding).
    dither_dist: Same as `dither_rate`, but for the latents passed to
      `self.synthesis`.
    soft_round: If true, dithered latents go through `tfc.soft_round` before
      the noise is added and through `tfc.soft_round_conditional_mean`
      afterwards. Also forwarded to `self.prior`. It does not affect the
      non-dithered (rounding) path.
    guess_offset: If true, the non-dithered path also shifts the rounding
      grid by `tfc.quantization_offset(prior)`. Defaults to
      `self.guess_offset` when None. Must not be combined with `soft_round`.
    offset: Base offset of the rounding grid on the non-dithered path.
    seed: Seed forwarded to `tf.random.uniform` for the dither noise.

  Returns:
    A tuple `(y_dist, x_hat, rates)`:
      - `y_dist`: the perturbed latents fed to the synthesis transform.
      - `x_hat`: the reconstruction `self.synthesis(y_dist)`.
      - `rates`: `-sum(prior.log_prob(y_rate), axis=-1) / ln(2)`. This is in
        bits, assuming `log_prob` returns natural-log densities.

  Raises:
    AssertionError: If both `guess_offset` and `soft_round` are enabled, or
      if `x.shape[-1] != self.ndim_source`.
  """
  if guess_offset is None:
    guess_offset = self.guess_offset
  # It doesn't make sense to use both guess_offset and soft_round.
  assert not (guess_offset and soft_round), (
      "guess_offset and soft_round cannot both be enabled.")

  def perturb(inputs, dither, prior, offset):
    """Returns `inputs` with uniform dither noise, or rounded to the grid."""
    if dither:
      if soft_round:
        inputs = tfc.soft_round(inputs, alpha=self.alpha)
      inputs = inputs + tf.random.uniform(
          tf.shape(inputs), -.5, .5, dtype=self.dtype, seed=seed)
      if soft_round:
        inputs = tfc.soft_round_conditional_mean(inputs, alpha=self.alpha)
      return inputs
    if guess_offset:
      # Rebind instead of `+=` so a mutable array passed as `offset` by the
      # caller is never modified in place.
      offset = offset + tfc.quantization_offset(prior)
    return st_round(inputs - offset) + offset

  assert x.shape[-1] == self.ndim_source, (
      "Last axis of x must have size self.ndim_source.")
  y = self.analysis(x)
  prior = self.prior(soft_round=soft_round)
  y_dist = perturb(y, dither_dist, prior, offset)
  if dither_rate == dither_dist:
    # Both paths use the same perturbation, so reuse the same sample
    # (including the same noise draw).
    y_rate = y_dist
  else:
    y_rate = perturb(y, dither_rate, prior, offset)
  x_hat = self.synthesis(y_dist)
  log_probs = prior.log_prob(y_rate)
  # Compute ln(2) directly in `self.dtype`. `tf.math.log(2.)` would be
  # evaluated in float32 (TF's default for Python floats), which loses
  # precision once cast to float64.
  ln2 = tf.math.log(tf.constant(2., dtype=self.dtype))
  rates = -tf.reduce_sum(log_probs, axis=-1) / ln2
  return y_dist, x_hat, rates