def ensure_compatible_hparams()

in nmt/nmt.py [0:0]


def ensure_compatible_hparams(hparams, default_hparams, hparams_path=""):
  """Reconcile loaded hparams with the current defaults (old checkpoints).

  `hparams` is modified in place and the same object is returned.

  Steps, in order:
    1. `default_hparams` is passed through
       `utils.maybe_parse_standard_hparams` together with `hparams_path`
       (presumably overlaying values read from that path -- see the helper).
    2. If `hparams` carries a legacy `num_layers` field,
       `num_encoder_layers` and `num_decoder_layers` are added from it,
       each only when not already present.
    3. Every field present in the defaults but absent from `hparams` is
       added with its default value.
    4. Fields are overwritten with their default values. This covers every
       default field when `override_loaded_hparams` on the parsed defaults
       is truthy, and otherwise only the names in `INFERENCE_KEYS`. Each
       field whose value differs is reported via `utils.print_out` before
       being replaced.

  Args:
    hparams: the loaded hyperparameters. Assumed to expose the HParams-style
      `values()` / `add_hparam()` interface used below.
    default_hparams: hyperparameters reflecting the current defaults (same
      interface assumption).
    hparams_path: forwarded to `utils.maybe_parse_standard_hparams`.

  Returns:
    The same `hparams` object after the updates above.
  """
  defaults = utils.maybe_parse_standard_hparams(default_hparams, hparams_path)

  # Older checkpoints only stored a single `num_layers`; derive the split
  # encoder/decoder depths from it when they are missing.
  if hasattr(hparams, "num_layers"):
    for legacy_field in ("num_encoder_layers", "num_decoder_layers"):
      if not hasattr(hparams, legacy_field):
        hparams.add_hparam(legacy_field, hparams.num_layers)

  # Snapshot both sides, then back-fill any field introduced after the
  # loaded hparams were written.
  default_values = defaults.values()
  loaded_values = hparams.values()
  missing = [name for name in default_values if name not in loaded_values]
  for name in missing:
    hparams.add_hparam(name, default_values[name])

  # Refresh everything from the defaults on request; otherwise (inference)
  # refresh only the inference-related fields.
  refresh_all = getattr(defaults, "override_loaded_hparams", None)
  keys_to_refresh = default_values.keys() if refresh_all else INFERENCE_KEYS

  for name in keys_to_refresh:
    current = getattr(hparams, name)
    wanted = default_values[name]
    if current != wanted:
      utils.print_out("# Updating hparams.%s: %s -> %s" %
                      (name, str(current), str(wanted)))
      setattr(hparams, name, wanted)
  return hparams