in models/ms2020.py [0:0]
def call(self, x, training):
  """Runs the full model on a batch of images and returns the R-D loss.

  The image is analyzed into latents `y` and hyper-latents `z`. The
  hyperprior codes `z`; the decoded hyperprior then parameterizes a
  conditional entropy model that codes `y` slice by slice, with each
  slice optionally conditioned on previously decoded slices.

  Args:
    x: Batch of images, shape (batch, height, width, channels).
    training: Python bool; if True, entropy models use additive uniform
      noise so the rate term is differentiable.

  Returns:
    Tuple `(loss, total_bpp, mse)` where `loss = total_bpp + lmbda * mse`.
  """
  num_pixels = tf.cast(tf.reduce_prod(tf.shape(x)[1:-1]), tf.float32)

  # Analysis half: image -> latents y -> hyper-latents z.
  y = self.analysis_transform(x)
  y_shape = tf.shape(y)[1:-1]
  z = self.hyper_analysis_transform(y)

  # Entropy model for the hyper-latents (side information).
  side_entropy_model = tfc.ContinuousBatchedEntropyModel(
      self.hyperprior, coding_rank=3, compression=False,
      offset_heuristic=False)
  # With training=True the bit count is measured on the noisy z_tilde.
  _, side_bits = side_entropy_model(z, training=training)
  side_bpp = tf.reduce_mean(side_bits) / num_pixels

  # The hyper-synthesis transforms see the *rounded* hyper-latents;
  # quantize() overrides the gradient (straight-through estimator).
  z_rounded = side_entropy_model.quantize(z)
  means_prior = self.hyper_synthesis_mean_transform(z_rounded)
  scales_prior = self.hyper_synthesis_scale_transform(z_rounded)

  # Conditional entropy model shared by every slice of y.
  slice_entropy_model = tfc.LocationScaleIndexedEntropyModel(
      tfc.NoisyNormal, num_scales=self.num_scales, scale_fn=self.scale_fn,
      coding_rank=3, compression=False)

  decoded_slices = []
  slice_bpps = []
  for idx, raw_slice in enumerate(tf.split(y, self.num_slices, axis=-1)):
    # A negative max_support_slices means "condition on all prior slices".
    if self.max_support_slices < 0:
      support = decoded_slices
    else:
      support = decoded_slices[:self.max_support_slices]

    # Predict the mean of the current slice from the hyperprior + support.
    mean_ctx = tf.concat([means_prior] + support, axis=-1)
    mu = self.cc_mean_transforms[idx](mean_ctx)
    mu = mu[:, :y_shape[0], :y_shape[1], :]

    # `sigma` holds scale *indices* into the scale table, not raw scales.
    scale_ctx = tf.concat([scales_prior] + support, axis=-1)
    sigma = self.cc_scale_transforms[idx](scale_ctx)
    sigma = sigma[:, :y_shape[0], :y_shape[1], :]

    _, bits = slice_entropy_model(raw_slice, sigma, loc=mu, training=training)
    slice_bpps.append(tf.reduce_mean(bits) / num_pixels)

    # Synthesis path uses straight-through rounding, like z above.
    decoded = slice_entropy_model.quantize(raw_slice, loc=mu)
    # Latent residual prediction (LRP) refines the rounded slice.
    lrp_ctx = tf.concat([mean_ctx, decoded], axis=-1)
    lrp = self.lrp_transforms[idx](lrp_ctx)
    decoded += 0.5 * tf.math.tanh(lrp)
    decoded_slices.append(decoded)

  # Reassemble the slices and synthesize the reconstruction.
  x_hat = self.synthesis_transform(tf.concat(decoded_slices, axis=-1))

  # Rate is the sum over the hyperprior and every slice.
  total_bpp = tf.add_n(slice_bpps + [side_bpp])
  # Distortion: plain MSE; pixels are neither clipped nor rounded here.
  mse = tf.reduce_mean(tf.math.squared_difference(x, x_hat))
  # Rate-distortion Lagrangian: R + lambda * D.
  loss = total_bpp + self.lmbda * mse
  return loss, total_bpp, mse