nmt/utils/common_test_utils.py
# `standard_hparams_utils` lives one package up (nmt/standard_hparams_utils.py)
# in the tensorflow/nmt layout; the relative import below assumes that layout.
from .. import standard_hparams_utils


def create_test_hparams(unit_type="lstm",
                        encoder_type="uni",
                        num_layers=4,
                        attention="",
                        attention_architecture=None,
                        use_residual=False,
                        inference_indices=None,
                        num_translations_per_input=1,
                        beam_width=0,
                        init_op="uniform"):
  """Create training and inference test hparams.

  Starts from the standard hparams and overrides them with deliberately
  small values (5 units, a 5-word vocab, 1 train step) so test graphs
  build and run quickly.
  """
  num_residual_layers = 0
  if use_residual:
    # TODO(rzhao): Put num_residual_layers computation logic into
    # `model_utils.py`, so we can also test it here.
    num_residual_layers = 2
  standard_hparams = standard_hparams_utils.create_standard_hparams()
  # Networks
  standard_hparams.num_units = 5
  standard_hparams.num_encoder_layers = num_layers
  standard_hparams.num_decoder_layers = num_layers
  standard_hparams.dropout = 0.5
  standard_hparams.unit_type = unit_type
  standard_hparams.encoder_type = encoder_type
  standard_hparams.residual = use_residual
  standard_hparams.num_residual_layers = num_residual_layers
  # Attention mechanisms
  standard_hparams.attention = attention
  standard_hparams.attention_architecture = attention_architecture
  # Train
  standard_hparams.init_op = init_op
  standard_hparams.num_train_steps = 1
  standard_hparams.decay_scheme = ""
  # Infer
  standard_hparams.tgt_max_len_infer = 100
  standard_hparams.beam_width = beam_width
  standard_hparams.num_translations_per_input = num_translations_per_input
  # Misc
  standard_hparams.forget_bias = 0.0
  standard_hparams.random_seed = 3
  standard_hparams.language_model = False
  # Vocab
  standard_hparams.src_vocab_size = 5
  standard_hparams.tgt_vocab_size = 5
  standard_hparams.eos = "</s>"
  standard_hparams.sos = "<s>"
  standard_hparams.src_vocab_file = ""
  standard_hparams.tgt_vocab_file = ""
  standard_hparams.src_embed_file = ""
  standard_hparams.tgt_embed_file = ""
  # For inference.py test
  standard_hparams.subword_option = "bpe"
  standard_hparams.src = "src"
  standard_hparams.tgt = "tgt"
  standard_hparams.src_max_len = 400
  standard_hparams.tgt_eos_id = 0
  standard_hparams.inference_indices = inference_indices
  return standard_hparams
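
A minimal sketch of how this helper might be exercised in a test. The test
class name and the specific attention settings below are illustrative
assumptions for this sketch, not part of the module:

import tensorflow as tf


class CreateTestHparamsTest(tf.test.TestCase):
  # Hypothetical test class, shown only to illustrate the helper's use.

  def testOverridesAreApplied(self):
    hparams = create_test_hparams(
        encoder_type="bi",
        num_layers=2,
        attention="scaled_luong",          # assumed attention setting
        attention_architecture="standard",  # assumed architecture setting
        beam_width=3)
    # The helper pins tiny sizes so graphs stay cheap to build.
    self.assertEqual(2, hparams.num_encoder_layers)
    self.assertEqual(2, hparams.num_decoder_layers)
    self.assertEqual(3, hparams.beam_width)
    self.assertEqual(5, hparams.num_units)


if __name__ == "__main__":
  tf.test.main()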