in nmt/nmt.py [0:0]
def create_hparams(flags):
  """Build the training HParams object from parsed command-line flags.

  Most hyperparameters are copied unchanged from the attribute of `flags`
  with the same name. A handful are derived instead:
    * num_encoder_layers / num_decoder_layers use flags.num_layers when the
      side-specific flag is falsy (e.g. None or 0).
    * sos / eos use vocab_utils.SOS / vocab_utils.EOS when the flag is falsy.
    * metrics is flags.metrics split on "," into a list.
    * epoch_step always starts at 0 (it is not read from `flags`).

  Args:
    flags: parsed flags object exposing every attribute read below
      (presumably the namespace produced by the NMT argument parser —
      confirm at the call site).

  Returns:
    A tf.contrib.training.HParams instance holding the hyperparameters.
  """
  # Entries that are computed rather than copied verbatim from `flags`.
  derived = {
      "num_encoder_layers": flags.num_encoder_layers or flags.num_layers,
      "num_decoder_layers": flags.num_decoder_layers or flags.num_layers,
      "sos": flags.sos or vocab_utils.SOS,
      "eos": flags.eos or vocab_utils.EOS,
      "epoch_step": 0,  # record where we were within an epoch.
      "metrics": flags.metrics.split(","),
  }

  # Every hyperparameter name, in the order it is handed to HParams.
  ordered_names = (
      # Data
      "src", "tgt", "train_prefix", "dev_prefix", "test_prefix",
      "vocab_prefix", "embed_prefix", "out_dir",
      # Networks
      "num_units", "num_encoder_layers", "num_decoder_layers", "dropout",
      "unit_type", "encoder_type", "residual", "time_major",
      "num_embeddings_partitions",
      # Attention mechanisms
      "attention", "attention_architecture", "output_attention",
      "pass_hidden_state",
      # Train
      "optimizer", "num_train_steps", "batch_size", "init_op", "init_weight",
      "max_gradient_norm", "learning_rate", "warmup_steps", "warmup_scheme",
      "decay_scheme", "colocate_gradients_with_ops", "num_sampled_softmax",
      # Data constraints
      "num_buckets", "max_train", "src_max_len", "tgt_max_len",
      # Inference
      "src_max_len_infer", "tgt_max_len_infer", "infer_batch_size",
      # Advanced inference arguments
      "infer_mode", "beam_width", "length_penalty_weight",
      "coverage_penalty_weight", "sampling_temperature",
      "num_translations_per_input",
      # Vocab
      "sos", "eos", "subword_option", "check_special_token",
      "use_char_encode",
      # Misc
      "forget_bias", "num_gpus", "epoch_step", "steps_per_stats",
      "steps_per_external_eval", "share_vocab", "metrics",
      "log_device_placement", "random_seed", "override_loaded_hparams",
      "num_keep_ckpts", "avg_ckpts", "language_model", "num_intra_threads",
      "num_inter_threads",
  )

  # Derived entries win; everything else is read straight off `flags`.
  # The conditional means getattr is never called for derived names
  # (so e.g. flags.epoch_step is never accessed).
  hparams_kwargs = {
      name: derived[name] if name in derived else getattr(flags, name)
      for name in ordered_names
  }
  return tf.contrib.training.HParams(**hparams_kwargs)