in nmt/train.py [0:0]
def train(hparams, scope=None, target_session=""):
  """Train a translation model.

  Runs the main training loop: builds separate train/eval/infer graphs and
  sessions, restores or creates the model, then repeatedly steps the train op
  while periodically printing stats, saving checkpoints, and running
  internal/external evaluations. After training, re-evaluates the best
  checkpoints saved per metric.

  Args:
    hparams: Hyperparameter object; must expose (at least) the attributes
      read below: log_device_placement, out_dir, num_train_steps,
      steps_per_stats, steps_per_external_eval, avg_ckpts, dev_prefix,
      src, tgt, num_intra_threads, num_inter_threads, epoch_step, metrics,
      and per-metric "best_<metric>_dir" attributes.
    scope: Optional variable scope name passed to the model constructors.
    target_session: Execution target passed to tf.Session (e.g. for
      distributed training); "" means in-process.

  Returns:
    A tuple (final_eval_metrics, global_step) where final_eval_metrics is
    the third element of the last run_full_eval result and global_step is
    the step count at which training stopped.
  """
  log_device_placement = hparams.log_device_placement
  out_dir = hparams.out_dir
  num_train_steps = hparams.num_train_steps
  steps_per_stats = hparams.steps_per_stats
  steps_per_external_eval = hparams.steps_per_external_eval
  # Internal (dev-set) eval runs 10x less often than stats printing.
  steps_per_eval = 10 * steps_per_stats
  avg_ckpts = hparams.avg_ckpts

  # Default external eval cadence: 5x the internal eval cadence.
  if not steps_per_external_eval:
    steps_per_external_eval = 5 * steps_per_eval

  # Create model: separate graphs for training, internal eval, and inference.
  model_creator = get_model_creator(hparams)
  train_model = model_helper.create_train_model(model_creator, hparams, scope)
  eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
  infer_model = model_helper.create_infer_model(model_creator, hparams, scope)

  # Preload data for sample decoding.
  dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
  dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
  sample_src_data = inference.load_data(dev_src_file)
  sample_tgt_data = inference.load_data(dev_tgt_file)

  summary_name = "train_log"
  model_dir = hparams.out_dir

  # Log and output files
  log_file = os.path.join(out_dir, "log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)

  # TensorFlow model: one session per graph, all sharing the same config.
  config_proto = utils.get_config_proto(
      log_device_placement=log_device_placement,
      num_intra_threads=hparams.num_intra_threads,
      num_inter_threads=hparams.num_inter_threads)
  train_sess = tf.Session(
      target=target_session, config=config_proto, graph=train_model.graph)
  eval_sess = tf.Session(
      target=target_session, config=config_proto, graph=eval_model.graph)
  infer_sess = tf.Session(
      target=target_session, config=config_proto, graph=infer_model.graph)

  with train_model.graph.as_default():
    # Restores the latest checkpoint from model_dir if present, otherwise
    # initializes fresh variables; returns the current global step either way.
    loaded_train_model, global_step = model_helper.create_or_load_model(
        train_model.model, model_dir, train_sess, "train")

  # Summary writer
  summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, summary_name), train_model.graph)

  # First evaluation, before any training steps (baseline numbers).
  run_full_eval(
      model_dir, infer_model, infer_sess,
      eval_model, eval_sess, hparams,
      summary_writer, sample_src_data,
      sample_tgt_data, avg_ckpts)

  last_stats_step = global_step
  last_eval_step = global_step
  last_external_eval_step = global_step

  # This is the training loop.
  stats, info, start_train_time = before_train(
      loaded_train_model, train_model, train_sess, global_step, hparams, log_f)
  while global_step < num_train_steps:
    ### Run a step ###
    start_time = time.time()
    try:
      step_result = loaded_train_model.train(train_sess)
      hparams.epoch_step += 1
    except tf.errors.OutOfRangeError:
      # Finished going through the training dataset. Go to next epoch.
      hparams.epoch_step = 0
      utils.print_out(
          "# Finished an epoch, step %d. Perform external evaluation" %
          global_step)
      run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                        summary_writer, sample_src_data, sample_tgt_data)
      run_external_eval(infer_model, infer_sess, model_dir, hparams,
                        summary_writer)
      if avg_ckpts:
        run_avg_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, global_step)
      # Re-initialize the dataset iterator for the next epoch; skip_count 0
      # means start from the beginning of the training data.
      train_sess.run(
          train_model.iterator.initializer,
          feed_dict={train_model.skip_count_placeholder: 0})
      # Skip the stats/eval bookkeeping below; no step result this iteration.
      continue

    # Process step_result, accumulate stats, and write summary
    global_step, info["learning_rate"], step_summary = update_stats(
        stats, start_time, step_result)
    summary_writer.add_summary(step_summary, global_step)

    # Once in a while, we print statistics.
    if global_step - last_stats_step >= steps_per_stats:
      last_stats_step = global_step
      is_overflow = process_stats(
          stats, info, global_step, steps_per_stats, log_f)
      print_step_info("  ", global_step, info, get_best_results(hparams),
                      log_f)
      # process_stats signals overflow (e.g. diverged loss); stop training.
      if is_overflow:
        break

      # Reset statistics
      stats = init_stats()

    if global_step - last_eval_step >= steps_per_eval:
      last_eval_step = global_step
      utils.print_out("# Save eval, global step %d" % global_step)
      add_info_summaries(summary_writer, global_step, info)

      # Save checkpoint
      loaded_train_model.saver.save(
          train_sess,
          os.path.join(out_dir, "translate.ckpt"),
          global_step=global_step)

      # Evaluate on dev/test
      run_sample_decode(infer_model, infer_sess,
                        model_dir, hparams, summary_writer, sample_src_data,
                        sample_tgt_data)
      run_internal_eval(
          eval_model, eval_sess, model_dir, hparams, summary_writer)

    if global_step - last_external_eval_step >= steps_per_external_eval:
      last_external_eval_step = global_step

      # Save checkpoint first so external eval decodes the latest weights.
      loaded_train_model.saver.save(
          train_sess,
          os.path.join(out_dir, "translate.ckpt"),
          global_step=global_step)
      run_sample_decode(infer_model, infer_sess,
                        model_dir, hparams, summary_writer, sample_src_data,
                        sample_tgt_data)
      run_external_eval(
          infer_model, infer_sess, model_dir,
          hparams, summary_writer)
      if avg_ckpts:
        run_avg_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, global_step)

  # Done training: save the final checkpoint and run one last full eval.
  loaded_train_model.saver.save(
      train_sess,
      os.path.join(out_dir, "translate.ckpt"),
      global_step=global_step)

  (result_summary, _, final_eval_metrics) = (
      run_full_eval(
          model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
          summary_writer, sample_src_data, sample_tgt_data, avg_ckpts))
  print_step_info("# Final, ", global_step, info, result_summary, log_f)
  utils.print_time("# Done training!", start_train_time)

  summary_writer.close()

  # Re-evaluate the best checkpoint saved for each tracked metric
  # (directories are stored on hparams as "best_<metric>_dir").
  utils.print_out("# Start evaluating saved best models.")
  for metric in hparams.metrics:
    best_model_dir = getattr(hparams, "best_" + metric + "_dir")
    summary_writer = tf.summary.FileWriter(
        os.path.join(best_model_dir, summary_name), infer_model.graph)
    result_summary, best_global_step, _ = run_full_eval(
        best_model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data)
    print_step_info("# Best %s, " % metric, best_global_step, info,
                    result_summary, log_f)
    summary_writer.close()

    # Same again for the averaged-checkpoint variant, when enabled.
    if avg_ckpts:
      best_model_dir = getattr(hparams, "avg_best_" + metric + "_dir")
      summary_writer = tf.summary.FileWriter(
          os.path.join(best_model_dir, summary_name), infer_model.graph)
      result_summary, best_global_step, _ = run_full_eval(
          best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
          hparams, summary_writer, sample_src_data, sample_tgt_data)
      print_step_info("# Averaged Best %s, " % metric, best_global_step, info,
                      result_summary, log_f)
      summary_writer.close()

  return final_eval_metrics, global_step