in mesh_tensorflow/bert/run_squad.py [0:0]
def main(_):
  """Runs SQuAD training and/or prediction, as selected by command-line flags.

  Builds a `TPUEstimator` from the BERT config named by
  `FLAGS.bert_config_file`. When `FLAGS.do_train` is set, it trains on
  `FLAGS.train_file` (or on the pre-built TFRecord file named by
  `FLAGS.cached_train_file`). When `FLAGS.do_predict` is set, it runs
  prediction on `FLAGS.predict_file` and writes `predictions.json`,
  `nbest_predictions.json` and `null_odds.json` into `FLAGS.output_dir`.

  Args:
    _: Unused positional argument (presumably the argv list handed over by
      the app runner; the caller is not visible here).
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  bert_config = bert_lib.BertConfig.from_json_file(FLAGS.bert_config_file)
  validate_flags_or_throw(bert_config)

  tf.gfile.MakeDirs(FLAGS.output_dir)

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  # A cluster resolver is only created when a named TPU is requested;
  # otherwise the RunConfig receives `cluster=None`.
  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  run_config = tf.estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      tpu_config=tf.estimator.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_cores_per_replica=1,
          per_host_input_for_training=tf.estimator.tpu.InputPipelineConfig
          .BROADCAST))

  train_examples = None
  num_train_steps = None
  num_warmup_steps = None
  if FLAGS.do_train:
    # The raw examples are read even when a cached TFRecord file is supplied,
    # because the step counts below are derived from their number.
    train_examples = read_squad_examples(
        input_file=FLAGS.train_file, is_training=True)
    num_train_steps = int(
        len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    # Pre-shuffle the input to avoid having to make a very large shuffle
    # buffer in the `input_fn`. The fixed seed keeps the order reproducible.
    rng = random.Random(12345)
    rng.shuffle(train_examples)

  model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.estimator.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)

  if FLAGS.do_train:
    # Number of features written in this run; stays None when a cached
    # TFRecord file is reused, since no writer exists in that case.
    num_split_examples = None

    # We write to a temporary file to avoid storing very large constant tensors
    # in memory.
    if FLAGS.cached_train_file:
      train_tfrecords_file = FLAGS.cached_train_file
    else:
      train_writer = FeatureWriter(
          filename=os.path.join(FLAGS.output_dir, "train.tf_record"),
          is_training=True)
      convert_examples_to_features(
          examples=train_examples,
          tokenizer=tokenizer,
          max_seq_length=FLAGS.max_seq_length,
          doc_stride=FLAGS.doc_stride,
          max_query_length=FLAGS.max_query_length,
          is_training=True,
          output_fn=train_writer.process_feature)
      train_writer.close()
      train_tfrecords_file = train_writer.filename
      num_split_examples = train_writer.num_features

    tf.logging.info("***** Running training *****")
    tf.logging.info(" Num orig examples = %d", len(train_examples))
    # Only report the split count when it is known; referencing the writer
    # unconditionally would raise NameError on the cached-file path.
    if num_split_examples is not None:
      tf.logging.info(" Num split examples = %d", num_split_examples)
    tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info(" Num steps = %d", num_train_steps)
    # Drop the reference to the raw examples before training starts.
    del train_examples

    train_input_fn = input_fn_builder(
        input_file=train_tfrecords_file,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

  if FLAGS.do_predict:
    eval_examples = read_squad_examples(
        input_file=FLAGS.predict_file, is_training=False)

    eval_writer = FeatureWriter(
        filename=os.path.join(FLAGS.output_dir, "eval.tf_record"),
        is_training=False)
    eval_features = []

    def append_feature(feature):
      """Keeps the feature in memory and also writes it to the eval file."""
      eval_features.append(feature)
      eval_writer.process_feature(feature)

    convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=FLAGS.max_seq_length,
        doc_stride=FLAGS.doc_stride,
        max_query_length=FLAGS.max_query_length,
        is_training=False,
        output_fn=append_feature)
    eval_writer.close()

    tf.logging.info("***** Running predictions *****")
    tf.logging.info(" Num orig examples = %d", len(eval_examples))
    tf.logging.info(" Num split examples = %d", len(eval_features))
    tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size)

    predict_input_fn = input_fn_builder(
        input_file=eval_writer.filename,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=False)

    # If running eval on the TPU, you will need to specify the number of
    # steps.
    all_results = []
    for result in estimator.predict(
        predict_input_fn, yield_single_examples=True):
      # Progress message once every 1000 collected results.
      if len(all_results) % 1000 == 0:
        tf.logging.info("Processing example: %d", len(all_results))
      # Convert the per-example outputs to plain Python int/float values.
      unique_id = int(result["unique_ids"])
      start_logits = [float(x) for x in result["start_logits"].flat]
      end_logits = [float(x) for x in result["end_logits"].flat]
      all_results.append(
          RawResult(
              unique_id=unique_id,
              start_logits=start_logits,
              end_logits=end_logits))

    output_prediction_file = os.path.join(FLAGS.output_dir, "predictions.json")
    output_nbest_file = os.path.join(FLAGS.output_dir, "nbest_predictions.json")
    output_null_log_odds_file = os.path.join(FLAGS.output_dir, "null_odds.json")

    write_predictions(eval_examples, eval_features, all_results,
                      FLAGS.n_best_size, FLAGS.max_answer_length,
                      FLAGS.do_lower_case, output_prediction_file,
                      output_nbest_file, output_null_log_odds_file)