in models/hific/model.py [0:0]
def build_model(self, input_image, input_images_d_steps=None):
  """Builds the model, losses, and train ops.

  Args:
    input_image: A single (B, H, W, C) image, in [0, 255].
    input_images_d_steps: If training a discriminator, this is expected to
      be a (B*N, H, W, C) stack of images, where N is the number of
      sub-batches (one per discriminator step). See build_input.

  Returns:
    output_image and bitstrings if self.evaluation, else None.
  """
  if input_images_d_steps is None:
    input_images_d_steps = []
  else:
    input_images_d_steps.set_shape(
        self.input_spec["input_images_d_steps"].shape)
    input_images_d_steps = tf.split(
        input_images_d_steps, self.num_steps_disc)
  if self.evaluation and input_images_d_steps:
    raise ValueError("Only need input_image for eval! {}".format(
        input_images_d_steps))

  input_image.set_shape(self.input_spec["input_image"].shape)
  self.build_transforms()
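  # The perceptual (LPIPS) loss is only needed while training; its weight
  # comes from the loss config.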
  if self.training:
    self._lpips_loss = LPIPSLoss(self._lpips_weight_path)
    self._lpips_loss_weight = self._config.loss_config.lpips_weight

  if self._setup_discriminator:
    self.build_discriminator()
  # Global step needs to be created for train, val and eval.
  global_step = tf.train.get_or_create_global_step()

  # Compute output graph.
  nodes_gen, bpp_pair, bitstrings = \
      self._compute_compression_graph(input_image)
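  # In evaluation mode only the reconstruction and the bitstrings are needed;
  # no losses or train ops are built, so we can return early.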
  if self.evaluation:
    tf.logging.info("Evaluation mode: build_model done.")
    reconstruction = tf.clip_by_value(nodes_gen.reconstruction, 0, 255.)
    return reconstruction, bitstrings
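  # Run the compression graph once more per discriminator sub-batch. These
  # passes only feed the discriminator, so summaries are disabled for them.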
  nodes_disc = []  # List of Nodes, one per discriminator sub-batch.
  for i, sub_batch in enumerate(input_images_d_steps):
    with tf.name_scope("sub_batch_disc_{}".format(i)):
      nodes, _, _ = self._compute_compression_graph(
          sub_batch, create_summaries=False)
      nodes_disc.append(nodes)
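  # Optionally set up restoring weights from a pre-trained auto-encoder
  # checkpoint.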
  if self._auto_encoder_ckpt_path:
    self._prepare_auto_encoder_restore()
  # The following is inspired by compare_gan/gans/modular_gan.py:
  # Say we want to train the discriminator for D steps for every step of
  # generator training. With the unroll_graph=True option, the features
  # given to the model_fn are split into D + 1 sub-batches. The code below
  # then creates D train_ops for the discriminator, each feeding a different
  # sub-batch of features into the discriminator. The train_op for the
  # generator then depends on all these D train_ops and uses the last
  # (D+1-th) sub-batch.
  # Note that the graph is only created once.
  d_train_ops = []
  if self._setup_discriminator:
    tf.logging.info("Unrolling graph for discriminator")
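    # The discriminator keeps its own (non-trainable) step counter, which is
    # passed to the discriminator optimizer created below.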
    self._global_step_disc = tf.get_variable(
        "global_step_disc", [], dtype=global_step.dtype, trainable=False)
    with tf.name_scope("steps"):
      tf.summary.scalar("global_step", global_step)
      tf.summary.scalar("global_step_disc", self._global_step_disc)

    # Create the optimizer once, then call minimize on it multiple times
    # within self._train_discriminator.
    disc_optimizer = self._make_discriminator_optimizer(
        self._global_step_disc)
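    # Chain the discriminator train ops via control dependencies so that they
    # run sequentially, each on its own sub-batch; summaries are only created
    # for the first step.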
    for i, nodes in enumerate(nodes_disc):
      with tf.name_scope("train_disc_{}".format(i + 1)):
        with tf.control_dependencies(d_train_ops):
          d_train_ops.append(
              self._train_discriminator(
                  nodes, disc_optimizer, create_summaries=(i == 0)))
  # Depend on `d_train_ops`, which ensures all `self._num_steps_disc` steps
  # of the discriminator run before the generator training op.
  with tf.control_dependencies(d_train_ops):
    train_op = self._train_generator(nodes_gen, bpp_pair, global_step)

  if self.training:
    self._train_op = train_op
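  # In training mode nothing is returned; the generator train op (which also
  # depends on the discriminator steps) is stored on self._train_op.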