in research/a2n/train.py [0:0]
def train():
  """Runs the main training loop with the given parameters."""
  if FLAGS.task == 0 and not tf.gfile.Exists(FLAGS.output_dir):
    tf.gfile.MakeDirs(FLAGS.output_dir)
  # Read train/dev/test graphs, create datasets and model
  add_inverse_edge = FLAGS.model in \
      ["source_rel_attention", "source_path_attention"]
  train_graph, train_data = read_graph_data(
      kg_file=FLAGS.kg_file,
      add_reverse_graph=not add_inverse_edge,
      add_inverse_edge=add_inverse_edge,
      mode="train",
      num_epochs=FLAGS.num_epochs, batchsize=FLAGS.batchsize,
      max_neighbors=FLAGS.max_neighbors,
      max_negatives=FLAGS.max_negatives,
      text_kg_file=FLAGS.text_kg_file
  )
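  # With multiple parameter-server tasks, replica_device_setter places
  # variables round-robin on the ps jobs and ops on this worker; with
  # ps_tasks=0 it is effectively a no-op and everything stays local.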
  worker_device = "/job:{}".format(FLAGS.brain_job_name)
  with tf.device(
      tf.train.replica_device_setter(
          FLAGS.ps_tasks, worker_device=worker_device)):
    iterator = train_data.dataset.make_one_shot_iterator()
    candidate_scores, _, labels, model, is_train_ph, _ = create_model(
        train_graph, iterator
    )
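    # create_model (defined elsewhere in this file) returns, per batch: scores
    # over candidate target entities (used as logits below), the matching
    # labels, a dict of the model's encoder modules, and an is-training
    # placeholder that is fed during training.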
    # Create train loss and training op
    loss = losses.softmax_crossentropy(logits=candidate_scores, labels=labels)
    optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
    global_step = tf.Variable(0, name="global_step", trainable=False)
    train_op = get_train_op(loss, optimizer, FLAGS.grad_clip,
                            global_step=global_step)
    tf.summary.scalar("Loss", loss)
    run_options = tf.RunOptions(report_tensor_allocations_upon_oom=True)
    session_config = tf.ConfigProto(log_device_placement=True)
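    # report_tensor_allocations_upon_oom attaches a tensor-allocation report to
    # OOM errors; log_device_placement logs where each op runs, which helps
    # verify the ps/worker placement configured above.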
    # Create tf training session
    scaffold = tf.train.Scaffold(saver=tf.train.Saver(max_to_keep=1000))
    # ckpt_hook = tf.train.CheckpointSaverHook(
    #     checkpoint_dir=FLAGS.output_dir, scaffold=scaffold,
    #     save_steps=FLAGS.save_every
    # )
    # summary_hook = tf.train.SummarySaverHook(
    #     save_secs=60, output_dir=FLAGS.output_dir,
    #     summary_op=tf.summary.merge_all()
    # )
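    # Checkpointing and summary writing are handled by MonitoredTrainingSession
    # below (save_checkpoint_steps / save_summaries_secs), so the explicit
    # saver and summary hooks above are left commented out.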
    session = tf.train.MonitoredTrainingSession(
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        checkpoint_dir=FLAGS.output_dir,
        save_checkpoint_steps=FLAGS.save_every,
        scaffold=scaffold,
        save_summaries_secs=60,
        # hooks=[summary_hook],
        # chief_only_hooks=[ckpt_hook],
        config=session_config
    )
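    # Only the chief (task 0) writes checkpoints and summaries; non-chief
    # workers wait for the chief to initialize the session and then just run
    # training steps.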
    # Create embeddings visualization
    if FLAGS.task == 0:
      utils.save_embedding_vocabs(FLAGS.output_dir, train_graph,
                                  FLAGS.entity_names_file)
      pconfig = projector.ProjectorConfig()
      add_embedding_to_projector(
          pconfig, model["entity_encoder"].embeddings.name.split(":")[0],
          os.path.join(FLAGS.output_dir, "entity_vocab.tsv")
      )
      add_embedding_to_projector(
          pconfig, model["relation_encoder"].embeddings.name.split(":")[0],
          os.path.join(FLAGS.output_dir, "relation_vocab.tsv")
      )
      if FLAGS.text_kg_file:
        word_embeddings = model["text_encoder"].word_embedding_encoder.embeddings
        add_embedding_to_projector(
            pconfig, word_embeddings.name.split(":")[0],
            os.path.join(FLAGS.output_dir, "word_vocab.tsv")
        )
      projector.visualize_embeddings(
          SummaryWriterCache.get(FLAGS.output_dir), pconfig
      )
    # Main training loop
    running_total_loss = 0.
    nsteps = 0
    gc.collect()
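    # The one-shot iterator yields batches for FLAGS.num_epochs epochs; once it
    # is exhausted, session.run raises tf.errors.OutOfRangeError and the loop
    # terminates.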
    while True:
      try:
        current_loss, _, _ = session.run(
            [loss, train_op, global_step],
            # feed_dict={is_train_ph: True, handle: train_iterator_handle},
            feed_dict={is_train_ph: True},
            options=run_options
        )
        nsteps += 1
        running_total_loss += current_loss
        tf.logging.info("Step %d, loss: %.3f, running avg loss: %.3f",
                        nsteps, current_loss, running_total_loss / nsteps)
        if nsteps % 2 == 0:
          gc.collect()
      except tf.errors.OutOfRangeError:
        tf.logging.info("End of Training Epochs after %d steps", nsteps)
        break
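
# Note (not part of this excerpt): train() reads its configuration from FLAGS,
# so it is typically invoked from the module's entry point after flag parsing
# (e.g. via tf.app.run()) rather than called directly.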