in nmt/nmt.py [0:0]
def run_main(flags, default_hparams, train_fn, inference_fn, target_session=""):
  """Run main: set up the job, load hparams, then either decode or train.

  Steps, in order:
    1. Log the job id and the devices TensorFlow reports.
    2. When `flags.random_seed` is set and positive, seed Python's `random`
       and NumPy's global RNG with `random_seed + jobid`.
    3. Create `flags.out_dir` if it is set and does not exist yet.
    4. Load hparams. If `flags.ckpt` is set and either an `hparams` file sits
       next to the checkpoint or `flags.hparams_path` is given, load them
       from the checkpoint's directory without saving. Otherwise load them
       from `flags.out_dir`, passing `save_hparams=True` only for job 0.
    5. If `flags.inference_input_file` is set, run `inference_fn`. When a
       reference file is also given and the output file exists, print a
       score for every metric in `hparams.metrics`. Otherwise call
       `train_fn`.

  Args:
    flags: parsed command-line flags. Fields read here: jobid, num_workers,
      random_seed, out_dir, ckpt, hparams_path, inference_input_file,
      inference_output_file, inference_list, inference_ref_file.
    default_hparams: passed through to `create_or_load_hparams`.
    train_fn: called as `train_fn(hparams, target_session=target_session)`.
    inference_fn: called as `inference_fn(ckpt, inference_input_file,
      inference_output_file, hparams, num_workers, jobid)`.
    target_session: forwarded unchanged to `train_fn`.

  Raises:
    AssertionError: if hparams have to come from `out_dir` but it is empty,
      or if inference is requested without `inference_output_file`.
  """
  # Job
  jobid = flags.jobid
  num_workers = flags.num_workers
  utils.print_out("# Job id %d" % jobid)

  # GPU device. Use the throwaway session as a context manager so it is
  # closed deterministically instead of whenever it gets garbage-collected.
  with tf.Session() as device_sess:
    devices = device_sess.list_devices()
  utils.print_out(
      "# Devices visible to TensorFlow: %s" % repr(devices))

  # Random
  random_seed = flags.random_seed
  if random_seed is not None and random_seed > 0:
    utils.print_out("# Set random seed to %d" % random_seed)
    # Offset by jobid so each worker draws a different random stream.
    random.seed(random_seed + jobid)
    np.random.seed(random_seed + jobid)

  # Model output directory
  out_dir = flags.out_dir
  if out_dir and not tf.gfile.Exists(out_dir):
    utils.print_out("# Creating output directory %s ..." % out_dir)
    tf.gfile.MakeDirs(out_dir)

  # Load hparams.
  loaded_hparams = False
  if flags.ckpt:  # Try to load hparams from the same directory as ckpt
    ckpt_dir = os.path.dirname(flags.ckpt)
    ckpt_hparams_file = os.path.join(ckpt_dir, "hparams")
    if tf.gfile.Exists(ckpt_hparams_file) or flags.hparams_path:
      hparams = create_or_load_hparams(
          ckpt_dir, default_hparams, flags.hparams_path,
          save_hparams=False)
      loaded_hparams = True
  if not loaded_hparams:  # Try to load from out_dir
    assert out_dir, (
        "out_dir must be set when hparams cannot be loaded from the "
        "checkpoint directory")
    hparams = create_or_load_hparams(
        out_dir, default_hparams, flags.hparams_path,
        save_hparams=(jobid == 0))

  ## Train / Decode
  if flags.inference_input_file:
    # Inference output directory
    trans_file = flags.inference_output_file
    assert trans_file, (
        "inference_output_file must be set when inference_input_file is set")
    trans_dir = os.path.dirname(trans_file)
    # A bare filename has an empty dirname: there is nothing to create, so
    # do not hand "" to gfile.
    if trans_dir and not tf.gfile.Exists(trans_dir):
      tf.gfile.MakeDirs(trans_dir)

    # Inference indices: optional comma-separated list of integers.
    hparams.inference_indices = None
    if flags.inference_list:
      hparams.inference_indices = [
          int(token) for token in flags.inference_list.split(",")]

    # Inference: an explicit checkpoint wins; otherwise use the latest one
    # in out_dir.
    # NOTE(review): latest_checkpoint may return None when out_dir holds no
    # checkpoint; that None is passed to inference_fn unchanged.
    ckpt = flags.ckpt or tf.train.latest_checkpoint(out_dir)
    inference_fn(ckpt, flags.inference_input_file,
                 trans_file, hparams, num_workers, jobid)

    # Evaluation
    ref_file = flags.inference_ref_file
    if ref_file and tf.gfile.Exists(trans_file):
      for metric in hparams.metrics:
        score = evaluation_utils.evaluate(
            ref_file,
            trans_file,
            metric,
            hparams.subword_option)
        utils.print_out(" %s: %.1f" % (metric, score))
  else:
    # Train
    train_fn(hparams, target_session=target_session)