in nmt/nmt.py [0:0]
def extend_hparams(hparams):
  """Validate `hparams` and add derived arguments to it.

  The function first runs sanity checks on the encoder / decoder / inference
  settings, then computes and stores (via `_add_argument`) the following
  derived values on `hparams`:

    * `num_encoder_residual_layers`, `num_decoder_residual_layers`
    * `src_vocab_size`, `tgt_vocab_size`, `src_vocab_file`, `tgt_vocab_file`
    * `num_enc_emb_partitions`, `num_dec_emb_partitions`
    * `src_embed_file`, `tgt_embed_file`
    * `best_<metric>` / `best_<metric>_dir` for every entry of
      `hparams.metrics` (plus the `avg_best_` variants when `avg_ckpts` is
      set).

  Some existing fields may also be overwritten: `pass_hidden_state` is forced
  to False when encoder and decoder depths differ, and language-model mode
  resets `attention`, `attention_architecture`, `pass_hidden_state`,
  `share_vocab` and `src`.

  Args:
    hparams: Hyper-parameter object. It is mutated in place.

  Returns:
    The same `hparams` object that was passed in.

  Raises:
    ValueError: If a sanity check fails: odd `num_encoder_layers` with the
      "bi" encoder, fewer than 2 encoder layers with the "gnmt" attention
      architecture, unknown `subword_option`, non-positive `beam_width` /
      `sampling_temperature` for the corresponding `infer_mode`, a zero or
      unset `num_encoder_layers` / `num_decoder_layers`, or a missing
      `vocab_prefix`.
  """
  # Sanity checks
  if hparams.encoder_type == "bi" and hparams.num_encoder_layers % 2 != 0:
    raise ValueError("For bi, num_encoder_layers %d should be even" %
                     hparams.num_encoder_layers)
  if (hparams.attention_architecture in ["gnmt"] and
      hparams.num_encoder_layers < 2):
    raise ValueError("For gnmt attention architecture, "
                     "num_encoder_layers %d should be >= 2" %
                     hparams.num_encoder_layers)
  if hparams.subword_option and hparams.subword_option not in ["spm", "bpe"]:
    raise ValueError("subword option must be either spm, or bpe")
  if hparams.infer_mode == "beam_search" and hparams.beam_width <= 0:
    # Adjacent literals are concatenated verbatim, so the separating space
    # has to be written explicitly.
    raise ValueError("beam_width must be greater than 0 when using "
                     "beam_search decoder.")
  if hparams.infer_mode == "sample" and hparams.sampling_temperature <= 0.0:
    raise ValueError("sampling_temperature must be greater than 0.0 when "
                     "using sample decoder.")

  # Different number of encoder / decoder layers
  # Raise explicitly instead of `assert` so the check survives `python -O`.
  if not (hparams.num_encoder_layers and hparams.num_decoder_layers):
    raise ValueError(
        "num_encoder_layers and num_decoder_layers must both be set and "
        "non-zero.")
  if hparams.num_encoder_layers != hparams.num_decoder_layers:
    hparams.pass_hidden_state = False
    utils.print_out("Num encoder layer %d is different from num decoder layer"
                    " %d, so set pass_hidden_state to False" % (
                        hparams.num_encoder_layers,
                        hparams.num_decoder_layers))

  # Set residual layers
  num_encoder_residual_layers = 0
  num_decoder_residual_layers = 0
  if hparams.residual:
    # By default every layer except one gets a residual connection.
    if hparams.num_encoder_layers > 1:
      num_encoder_residual_layers = hparams.num_encoder_layers - 1
    if hparams.num_decoder_layers > 1:
      num_decoder_residual_layers = hparams.num_decoder_layers - 1

    if hparams.encoder_type == "gnmt":
      # The first unidirectional layer (after the bi-directional layer) in
      # the GNMT encoder can't have a residual connection because its input
      # is the concatenation of fw_cell and bw_cell's outputs.
      num_encoder_residual_layers = hparams.num_encoder_layers - 2

      # Compatible for GNMT models
      if hparams.num_encoder_layers == hparams.num_decoder_layers:
        num_decoder_residual_layers = num_encoder_residual_layers
  _add_argument(hparams, "num_encoder_residual_layers",
                num_encoder_residual_layers)
  _add_argument(hparams, "num_decoder_residual_layers",
                num_decoder_residual_layers)

  # Language modeling
  if getattr(hparams, "language_model", None):
    hparams.attention = ""
    hparams.attention_architecture = ""
    hparams.pass_hidden_state = False
    hparams.share_vocab = True
    hparams.src = hparams.tgt
    utils.print_out("For language modeling, we turn off attention and "
                    "pass_hidden_state; turn on share_vocab; set src to tgt.")

  ## Vocab
  # Get vocab file names first
  if hparams.vocab_prefix:
    src_vocab_file = hparams.vocab_prefix + "." + hparams.src
    tgt_vocab_file = hparams.vocab_prefix + "." + hparams.tgt
  else:
    raise ValueError("hparams.vocab_prefix must be provided.")

  # Source vocab
  # `check_vocab` returns both a size and a file name; the returned file name
  # replaces the one built above.
  check_special_token = getattr(hparams, "check_special_token", True)
  src_vocab_size, src_vocab_file = vocab_utils.check_vocab(
      src_vocab_file,
      hparams.out_dir,
      check_special_token=check_special_token,
      sos=hparams.sos,
      eos=hparams.eos,
      unk=vocab_utils.UNK)

  # Target vocab
  if hparams.share_vocab:
    utils.print_out(" using source vocab for target")
    tgt_vocab_file = src_vocab_file
    tgt_vocab_size = src_vocab_size
  else:
    tgt_vocab_size, tgt_vocab_file = vocab_utils.check_vocab(
        tgt_vocab_file,
        hparams.out_dir,
        check_special_token=check_special_token,
        sos=hparams.sos,
        eos=hparams.eos,
        unk=vocab_utils.UNK)
  _add_argument(hparams, "src_vocab_size", src_vocab_size)
  _add_argument(hparams, "tgt_vocab_size", tgt_vocab_size)
  _add_argument(hparams, "src_vocab_file", src_vocab_file)
  _add_argument(hparams, "tgt_vocab_file", tgt_vocab_file)

  # Num embedding partitions
  num_embeddings_partitions = getattr(hparams, "num_embeddings_partitions", 0)
  _add_argument(hparams, "num_enc_emb_partitions", num_embeddings_partitions)
  _add_argument(hparams, "num_dec_emb_partitions", num_embeddings_partitions)

  # Pretrained Embeddings
  _add_argument(hparams, "src_embed_file", "")
  _add_argument(hparams, "tgt_embed_file", "")
  if getattr(hparams, "embed_prefix", None):
    src_embed_file = hparams.embed_prefix + "." + hparams.src
    tgt_embed_file = hparams.embed_prefix + "." + hparams.tgt

    # Each side is handled independently: an embedding file is only recorded
    # (and its partition count forced to 1) when it actually exists.
    if tf.gfile.Exists(src_embed_file):
      utils.print_out(" src_embed_file %s exist" % src_embed_file)
      hparams.src_embed_file = src_embed_file

      utils.print_out(
          "For pretrained embeddings, set num_enc_emb_partitions to 1")
      hparams.num_enc_emb_partitions = 1
    else:
      utils.print_out(" src_embed_file %s doesn't exist" % src_embed_file)

    if tf.gfile.Exists(tgt_embed_file):
      utils.print_out(" tgt_embed_file %s exist" % tgt_embed_file)
      hparams.tgt_embed_file = tgt_embed_file

      utils.print_out(
          "For pretrained embeddings, set num_dec_emb_partitions to 1")
      hparams.num_dec_emb_partitions = 1
    else:
      utils.print_out(" tgt_embed_file %s doesn't exist" % tgt_embed_file)

  # Evaluation
  for metric in hparams.metrics:
    best_metric_dir = os.path.join(hparams.out_dir, "best_" + metric)
    tf.gfile.MakeDirs(best_metric_dir)
    # NOTE(review): `update=False` presumably keeps an already-present best
    # score instead of resetting it to 0 -- confirm against `_add_argument`.
    _add_argument(hparams, "best_" + metric, 0, update=False)
    _add_argument(hparams, "best_" + metric + "_dir", best_metric_dir)

    if getattr(hparams, "avg_ckpts", None):
      best_metric_dir = os.path.join(hparams.out_dir, "avg_best_" + metric)
      tf.gfile.MakeDirs(best_metric_dir)
      _add_argument(hparams, "avg_best_" + metric, 0, update=False)
      _add_argument(hparams, "avg_best_" + metric + "_dir", best_metric_dir)

  return hparams