in models/hific/model.py [0:0]
def __init__(self,
             config,
             mode: ModelMode,
             lpips_weight_path=None,
             auto_encoder_ckpt_dir=None,
             create_image_summaries=True):
    """Instantiate model.

    Args:
      config: A config, see configs.py. Its `model_type` must be a
        `ModelType`; `num_steps_disc` is read when a discriminator is set up.
      mode: Model mode.
      lpips_weight_path: path to where LPIPS weights are stored or should be
        stored. See helpers.ensure_lpips_weights_exist. When `self.training`
        is true and this is falsy, it falls back to
        "lpips_weight__net-lin_alex_v0.1.pb".
      auto_encoder_ckpt_dir: If given, instantiate auto-encoder and probability
        model from latest checkpoint in this folder.
      create_image_summaries: Whether to create image summaries. Turn off to
        save disk space.

    Raises:
      ValueError: if `config.model_type` is not a `ModelType`; if
        `auto_encoder_ckpt_dir` is given but contains no checkpoint; or if a
        discriminator is to be set up while `config.num_steps_disc == 0`.
    """
    # `_mode` is assigned first because `self.training` / `self.validation`
    # are consulted further down.
    self._mode = mode
    self._config = config
    self._model_type = config.model_type
    self._create_image_summaries = create_image_summaries
    if not isinstance(self._model_type, ModelType):
        raise ValueError(f"Invalid model_type: [{self._config.model_type}]")

    # Optionally resolve the newest auto-encoder checkpoint up front, so a
    # missing checkpoint fails at construction time.
    self._auto_encoder_savers = None
    resolved_ckpt = None
    if auto_encoder_ckpt_dir:
        resolved_ckpt = tf.train.latest_checkpoint(auto_encoder_ckpt_dir)
        if not resolved_ckpt:
            raise ValueError(
                f"Did not find checkpoint in {auto_encoder_ckpt_dir}!")
    self._auto_encoder_ckpt_path = resolved_ckpt

    self._lpips_weight_path = (
        "lpips_weight__net-lin_alex_v0.1.pb"
        if self.training and not lpips_weight_path
        else lpips_weight_path)

    # Sub-networks and losses; these start out empty/None here and are
    # presumably populated later by the model-building code.
    self._transform_layers = []
    self._entropy_layers = []
    self._layers = None
    self._encoder = None
    self._decoder = None
    self._discriminator = None
    self._gan_loss_function = None
    self._lpips_loss_weight = None
    self._lpips_loss = None
    self._entropy_model = None
    self._optimize_entropy_vars = True
    self._global_step_disc = None  # global_step used for D training

    # A discriminator is only wanted for GAN models outside of evaluation,
    # i.e. when training or validating.
    self._setup_discriminator = (
        self._model_type == ModelType.COMPRESSION_GAN
        and (self.training or self.validation))
    self._num_steps_disc = (
        self._config.num_steps_disc if self._setup_discriminator else 0)
    if self._setup_discriminator:
        if self._num_steps_disc == 0:
            raise ValueError(
                f"model_type=={self._model_type} but num_steps_disc == 0.")
        tf.logging.info(
            f"GAN Training enabled. Training discriminator for "
            f"{self._num_steps_disc} steps.")

    def make_image_spec():
        # A fresh uint8 spec for batches of 3-channel images of any size.
        return tf.keras.layers.InputSpec(
            dtype=tf.uint8,
            shape=(None, None, None, 3))

    self.input_spec = {"input_image": make_image_spec()}
    if self._setup_discriminator:
        # This is an optional argument to build_model. If training a
        # discriminator, this is expected to contain multiple sub-batches.
        # See build_input for details.
        self.input_spec["input_images_d_steps"] = make_image_spec()

    self._gan_loss_function = compare_gan_loss_lib.non_saturating
    # Loss schedules are ignored whenever we are neither training nor
    # validating.
    self._loss_scaler = _LossScaler(
        self._config,
        ignore_schedules=not (self.training or self.validation))
    self._train_op = None
    self._hooks = []