def train()

in research/a2n/train.py [0:0]


def train():
  """Running the main training loop with given parameters."""
  if FLAGS.task == 0 and not tf.gfile.Exists(FLAGS.output_dir):
    tf.gfile.MakeDirs(FLAGS.output_dir)

  # Read train/dev/test graphs, create datasets and model
  add_inverse_edge = FLAGS.model in [
      "source_rel_attention", "source_path_attention"]
  train_graph, train_data = read_graph_data(
      kg_file=FLAGS.kg_file,
      add_reverse_graph=not add_inverse_edge,
      add_inverse_edge=add_inverse_edge,
      mode="train",
      num_epochs=FLAGS.num_epochs, batchsize=FLAGS.batchsize,
      max_neighbors=FLAGS.max_neighbors,
      max_negatives=FLAGS.max_negatives,
      text_kg_file=FLAGS.text_kg_file
  )

  worker_device = "/job:{}".format(FLAGS.brain_job_name)
  with tf.device(
      tf.train.replica_device_setter(
          FLAGS.ps_tasks, worker_device=worker_device)):
    iterator = train_data.dataset.make_one_shot_iterator()
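    # create_model returns candidate scores, labels, the encoder dict, and the
    # is-train placeholder; the remaining outputs are unused here.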
    candidate_scores, _, labels, model, is_train_ph, _ = create_model(
        train_graph, iterator
    )

  # Create train loss and training op
  loss = losses.softmax_crossentropy(logits=candidate_scores, labels=labels)
  optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
  global_step = tf.Variable(0, name="global_step", trainable=False)
  train_op = get_train_op(loss, optimizer, FLAGS.grad_clip,
                          global_step=global_step)
  tf.summary.scalar("Loss", loss)

  run_options = tf.RunOptions(report_tensor_allocations_upon_oom=True)
  session_config = tf.ConfigProto(log_device_placement=True)

  # Create tf training session; MonitoredTrainingSession itself handles
  # checkpointing (save_checkpoint_steps) and summary writing
  # (save_summaries_secs).
  scaffold = tf.train.Scaffold(saver=tf.train.Saver(max_to_keep=1000))
  # ckpt_hook = tf.train.CheckpointSaverHook(
  #     checkpoint_dir=FLAGS.output_dir, scaffold=scaffold,
  #     save_steps=FLAGS.save_every
  # )
  # summary_hook = tf.train.SummarySaverHook(
  #     save_secs=60, output_dir=FLAGS.output_dir,
  #     summary_op=tf.summary.merge_all()
  # )
  session = tf.train.MonitoredTrainingSession(
      master=FLAGS.master,
      is_chief=(FLAGS.task == 0),
      checkpoint_dir=FLAGS.output_dir,
      save_checkpoint_steps=FLAGS.save_every,
      scaffold=scaffold,
      save_summaries_secs=60,
      # hooks=[summary_hook],
      # chief_only_hooks=[ckpt_hook],
      config=session_config
  )

  # Create embeddings visualization
  if FLAGS.task == 0:
    utils.save_embedding_vocabs(FLAGS.output_dir, train_graph,
                                FLAGS.entity_names_file)
    pconfig = projector.ProjectorConfig()
    add_embedding_to_projector(
        pconfig, model["entity_encoder"].embeddings.name.split(":")[0],
        os.path.join(FLAGS.output_dir, "entity_vocab.tsv")
    )
    add_embedding_to_projector(
        pconfig, model["relation_encoder"].embeddings.name.split(":")[0],
        os.path.join(FLAGS.output_dir, "relation_vocab.tsv")
    )
    if FLAGS.text_kg_file:
      word_embeddings = model["text_encoder"].word_embedding_encoder.embeddings
      add_embedding_to_projector(
          pconfig, word_embeddings.name.split(":")[0],
          os.path.join(FLAGS.output_dir, "word_vocab.tsv")
      )
    projector.visualize_embeddings(
        SummaryWriterCache.get(FLAGS.output_dir), pconfig
    )

  # Main training loop
  running_total_loss = 0.
  nsteps = 0
  gc.collect()
  while True:
    try:
      current_loss, _, _ = session.run(
          [loss, train_op, global_step],
          # feed_dict={is_train_ph: True, handle: train_iterator_handle},
          feed_dict={is_train_ph: True},
          options=run_options
      )
      nsteps += 1
      running_total_loss += current_loss
      tf.logging.info("Step %d, loss: %.3f, running avg loss: %.3f",
                      nsteps, current_loss, running_total_loss / nsteps)
      if nsteps % 2 == 0:
        gc.collect()
    except tf.errors.OutOfRangeError:
      tf.logging.info("End of Traning Epochs after %d steps", nsteps)
      break
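
The get_train_op helper called above is not included in this excerpt. A minimal sketch of such a helper, assuming it clips gradients by global norm (FLAGS.grad_clip) before applying the Adam update, might look like the following; the actual implementation in research/a2n may differ.

def get_train_op(loss, optimizer, grad_clip, global_step=None):
  """Sketch only: clip gradients by global norm, then apply the update."""
  grads_and_vars = optimizer.compute_gradients(loss)
  grads, variables = zip(*grads_and_vars)
  if grad_clip:
    grads, _ = tf.clip_by_global_norm(grads, grad_clip)
  return optimizer.apply_gradients(
      zip(grads, variables), global_step=global_step)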