def create_test_hparams()

in nmt/utils/common_test_utils.py [0:0]


def create_test_hparams(unit_type="lstm",
                        encoder_type="uni",
                        num_layers=4,
                        attention="",
                        attention_architecture=None,
                        use_residual=False,
                        inference_indices=None,
                        num_translations_per_input=1,
                        beam_width=0,
                        init_op="uniform"):
  """Build a small hparams object shared by training and inference tests.

  Starts from `standard_hparams_utils.create_standard_hparams()` and overrides
  a fixed set of fields with tiny, test-sized values, plus the fields exposed
  here as keyword arguments.

  Args:
    unit_type: Stored in `unit_type`.
    encoder_type: Stored in `encoder_type`.
    num_layers: Stored in both `num_encoder_layers` and `num_decoder_layers`.
    attention: Stored in `attention`.
    attention_architecture: Stored in `attention_architecture`.
    use_residual: Stored in `residual`; when truthy, `num_residual_layers` is
      set to 2, otherwise to 0.
    inference_indices: Stored in `inference_indices`.
    num_translations_per_input: Stored in `num_translations_per_input`.
    beam_width: Stored in `beam_width`.
    init_op: Stored in `init_op`.

  Returns:
    The object returned by `create_standard_hparams()`, with the overrides
    below assigned onto it as attributes.
  """
  # TODO(rzhao): Put num_residual_layers computation logic into
  # `model_utils.py`, so we can also test it here.
  num_residual_layers = 2 if use_residual else 0

  # Ordered (attribute, value) pairs; a tuple keeps assignment order fixed
  # regardless of interpreter dict-ordering semantics.
  overrides = (
      # Networks
      ("num_units", 5),
      ("num_encoder_layers", num_layers),
      ("num_decoder_layers", num_layers),
      ("dropout", 0.5),
      ("unit_type", unit_type),
      ("encoder_type", encoder_type),
      ("residual", use_residual),
      ("num_residual_layers", num_residual_layers),
      # Attention mechanisms
      ("attention", attention),
      ("attention_architecture", attention_architecture),
      # Train
      ("init_op", init_op),
      ("num_train_steps", 1),
      ("decay_scheme", ""),
      # Infer
      ("tgt_max_len_infer", 100),
      ("beam_width", beam_width),
      ("num_translations_per_input", num_translations_per_input),
      # Misc
      ("forget_bias", 0.0),
      ("random_seed", 3),
      ("language_model", False),
      # Vocab
      ("src_vocab_size", 5),
      ("tgt_vocab_size", 5),
      ("eos", "</s>"),
      ("sos", "<s>"),
      ("src_vocab_file", ""),
      ("tgt_vocab_file", ""),
      ("src_embed_file", ""),
      ("tgt_embed_file", ""),
      # For inference.py test
      ("subword_option", "bpe"),
      ("src", "src"),
      ("tgt", "tgt"),
      ("src_max_len", 400),
      ("tgt_eos_id", 0),
      ("inference_indices", inference_indices),
  )

  hparams = standard_hparams_utils.create_standard_hparams()
  for attr_name, attr_value in overrides:
    setattr(hparams, attr_name, attr_value)
  return hparams