in nmt/model_helper.py:
def create_train_model(
    model_creator, hparams, scope=None, num_workers=1, jobid=0,
    extra_args=None):
  """Create train graph, model, and iterator."""
  src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
  tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file

  graph = tf.Graph()

  with graph.as_default(), tf.container(scope or "train"):
    # Lookup tables mapping source/target tokens to vocabulary ids.
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, hparams.share_vocab)

    # One sentence per line; source and target files are read in parallel.
    src_dataset = tf.data.TextLineDataset(tf.gfile.Glob(src_file))
    tgt_dataset = tf.data.TextLineDataset(tf.gfile.Glob(tgt_file))
    # Fed at iterator-initialization time to skip already-consumed examples.
    skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

    iterator = iterator_utils.get_iterator(
        src_dataset,
        tgt_dataset,
        src_vocab_table,
        tgt_vocab_table,
        batch_size=hparams.batch_size,
        sos=hparams.sos,
        eos=hparams.eos,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len,
        tgt_max_len=hparams.tgt_max_len,
        skip_count=skip_count_placeholder,
        num_shards=num_workers,
        shard_index=jobid,
        use_char_encode=hparams.use_char_encode)

    # Note: One can set model_device_fn to
    # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
    model_device_fn = None
    if extra_args:
      model_device_fn = extra_args.model_device_fn
    with tf.device(model_device_fn):
      model = model_creator(
          hparams,
          iterator=iterator,
          mode=tf.contrib.learn.ModeKeys.TRAIN,
          source_vocab_table=src_vocab_table,
          target_vocab_table=tgt_vocab_table,
          scope=scope,
          extra_args=extra_args)

  return TrainModel(
      graph=graph,
      model=model,
      iterator=iterator,
      skip_count_placeholder=skip_count_placeholder)
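
For context, a minimal usage sketch (not part of model_helper.py) of how the returned TrainModel tuple is typically consumed: the session is built against train_model.graph, variables and vocabulary lookup tables are initialized, and the iterator is initialized with a value fed to skip_count_placeholder. The names model_creator and hparams, and the model.train() step call, are assumptions standing in for the rest of the nmt package.

import tensorflow as tf

# Usage sketch only; `model_creator` and `hparams` are assumed to come from
# the surrounding nmt package.
train_model = create_train_model(model_creator, hparams, scope="train")

# The TrainModel tuple owns its own tf.Graph, so the session must be created
# against that graph rather than the default one.
train_sess = tf.Session(graph=train_model.graph)

with train_model.graph.as_default():
  train_sess.run(tf.global_variables_initializer())
  train_sess.run(tf.tables_initializer())  # the vocab lookup tables need this

  # Feeding 0 starts from the beginning of the training files; a resumed job
  # can instead feed the number of examples it has already consumed.
  train_sess.run(
      train_model.iterator.initializer,
      feed_dict={train_model.skip_count_placeholder: 0})

  # One training step; the exact step API (`model.train`) is an assumption
  # about the class produced by `model_creator`.
  step_result = train_model.model.train(train_sess)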
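
The note about tf.train.replica_device_setter above works as follows: it returns a device function that assigns variables to parameter-server tasks while leaving ops on the worker, and a caller can route such a function in through extra_args.model_device_fn. A small hedged illustration (task counts and device strings are placeholders):

import tensorflow as tf

# Illustrative only: variables created under this device scope are assigned
# round-robin to the parameter-server tasks; ops stay on the worker device.
device_fn = tf.train.replica_device_setter(
    ps_tasks=2, worker_device="/job:worker/task:0")

with tf.device(device_fn):
  w = tf.get_variable("w", shape=[4, 4])  # placed on /job:ps/task:0 or 1

# Passing a function like `device_fn` as extra_args.model_device_fn makes
# create_train_model build the model under the same placement policy.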