in nmt/inference.py [0:0]
def multi_worker_inference(sess,
                           infer_model,
                           loaded_infer_model,
                           inference_input_file,
                           inference_output_file,
                           hparams,
                           num_workers,
                           jobid):
  """Inference using multiple workers.

  Each worker decodes its shard of the input into
  "<inference_output_file>_<jobid>", then renames that file to
  "<inference_output_file>_done_<jobid>" to signal completion. Worker 0
  waits for every per-worker done-file, concatenates them in worker order
  into `inference_output_file`, and finally removes the intermediate files.

  Args:
    sess: TensorFlow session used to run the inference graph.
    infer_model: Inference model tuple (graph, iterator, placeholders).
    loaded_infer_model: Model object with restored inference parameters.
    inference_input_file: Path to the file of source sentences to translate.
    inference_output_file: Path for the final merged translation output.
    hparams: Hyperparameters (infer_batch_size, metrics, beam_width, eos,
      subword_option, num_translations_per_input, infer_mode).
    num_workers: Total number of inference workers; must be > 1.
    jobid: Zero-based id of this worker.

  Raises:
    ValueError: If num_workers is not greater than 1.
  """
  # Explicit check instead of `assert`: asserts vanish under `python -O`,
  # which would silently corrupt the shard arithmetic below.
  if num_workers <= 1:
    raise ValueError(
        "multi_worker_inference requires num_workers > 1, got %d" %
        num_workers)

  final_output_infer = inference_output_file
  output_infer = "%s_%d" % (inference_output_file, jobid)
  output_infer_done = "%s_done_%d" % (inference_output_file, jobid)

  # Read data
  infer_data = load_data(inference_input_file, hparams)

  # Split data to multiple workers: ceiling-divide so every sentence is
  # assigned; the last worker's slice is clamped to the data size.
  total_load = len(infer_data)
  load_per_worker = int((total_load - 1) / num_workers) + 1
  start_position = jobid * load_per_worker
  end_position = min(start_position + load_per_worker, total_load)
  infer_data = infer_data[start_position:end_position]

  with infer_model.graph.as_default():
    sess.run(infer_model.iterator.initializer,
             {
                 infer_model.src_placeholder: infer_data,
                 infer_model.batch_size_placeholder: hparams.infer_batch_size
             })

  # Decode this worker's shard to its private output file.
  utils.print_out("# Start decoding")
  nmt_utils.decode_and_evaluate(
      "infer",
      loaded_infer_model,
      sess,
      output_infer,
      ref_file=None,
      metrics=hparams.metrics,
      subword_option=hparams.subword_option,
      beam_width=hparams.beam_width,
      tgt_eos=hparams.eos,
      num_translations_per_input=hparams.num_translations_per_input,
      infer_mode=hparams.infer_mode)

  # Change file name to indicate the file writing is completed; the rename
  # is the cross-worker "done" signal that worker 0 polls for.
  tf.gfile.Rename(output_infer, output_infer_done, overwrite=True)

  # Job 0 is responsible for the clean up.
  if jobid != 0:
    return

  # Merge all per-worker translations, in worker order, into the final file.
  with codecs.getwriter("utf-8")(
      tf.gfile.GFile(final_output_infer, mode="wb")) as final_f:
    for worker_id in range(num_workers):
      worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id)
      while not tf.gfile.Exists(worker_infer_done):
        # Poll until the worker has renamed its output (completion signal).
        utils.print_out(" waiting job %d to complete." % worker_id)
        time.sleep(10)

      with codecs.getreader("utf-8")(
          tf.gfile.GFile(worker_infer_done, mode="rb")) as f:
        for translation in f:
          final_f.write("%s" % translation)

  # Delete intermediates only after the merged file is fully written, so a
  # crash during the merge leaves the per-worker outputs recoverable.
  for worker_id in range(num_workers):
    worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id)
    tf.gfile.Remove(worker_infer_done)