in nmt/nmt.py [0:0]
def ensure_compatible_hparams(hparams, default_hparams, hparams_path=""):
  """Reconcile hparams loaded from a checkpoint with the current defaults.

  Args:
    hparams: the loaded hparams object; it is modified in place.
    default_hparams: the current default hparams. It is first passed through
      `utils.maybe_parse_standard_hparams` together with `hparams_path`.
    hparams_path: forwarded to `utils.maybe_parse_standard_hparams`.

  Returns:
    The same `hparams` object, after these steps:
      1. If it has `num_layers` but lacks `num_encoder_layers` and/or
         `num_decoder_layers`, each missing one is added with the value of
         `num_layers`.
      2. Every field present in the defaults but absent from `hparams`
         (as of before this step) is added with its default value.
      3. Fields whose value differs from the default are overwritten with
         the default, and each change is logged via `utils.print_out`.
         When `default_hparams.override_loaded_hparams` is truthy, every
         default field is considered. Otherwise only the keys in
         `INFERENCE_KEYS` are considered.
  """
  default_hparams = utils.maybe_parse_standard_hparams(
      default_hparams, hparams_path)

  # Old checkpoints may carry only a single `num_layers`; derive the
  # separate encoder/decoder depths from it when they are missing.
  if hasattr(hparams, "num_layers"):
    for layer_field in ("num_encoder_layers", "num_decoder_layers"):
      if not hasattr(hparams, layer_field):
        hparams.add_hparam(layer_field, hparams.num_layers)

  defaults = default_hparams.values()
  # Snapshot of the loaded fields, taken before any backfilling below.
  loaded = hparams.values()

  # Backfill fields that exist in the defaults but not in the loaded hparams.
  for name, default_value in defaults.items():
    if name not in loaded:
      hparams.add_hparam(name, default_value)

  # With override_loaded_hparams every default wins; otherwise only the
  # inference-related keys are refreshed from the defaults.
  if getattr(default_hparams, "override_loaded_hparams", None):
    keys_to_refresh = defaults.keys()
  else:
    keys_to_refresh = INFERENCE_KEYS

  for name in keys_to_refresh:
    current_value = getattr(hparams, name)
    new_value = defaults[name]
    if current_value != new_value:
      utils.print_out("# Updating hparams.%s: %s -> %s" %
                      (name, str(current_value), str(new_value)))
      setattr(hparams, name, new_value)
  return hparams