def main()

in tensorflow_graphics/projects/nasa/track.py [0:0]


def main(unused_argv):
  """Tracks a motion sequence by optimizing the pose variable frame by frame.

  The function:
    1. Reads every example of the "test" split into memory, keyed by the first
       entry of its "name" field, and perturbs "vert" with Gaussian noise.
    2. Builds a graph in which `theta` ("pose_var") is the only variable handed
       to the optimizer; the per-frame transform derived from `theta` is
       composed with an accumulated transform fed through a placeholder.
    3. Iterates over the frames in sorted key order; for each frame it calls
       `utils.optimize_theta`, evaluates IoU, saves the mesh and point cloud,
       writes summaries, and folds the optimized transform into the
       accumulated transform used for the next frame.
    4. Writes the mean IoU over all frames to
       `<train_dir>/tracked_<gradient_type>/iou.txt`.

  Args:
    unused_argv: Not used by this function.

  Raises:
    ValueError: If the dataset yields no examples, or if
      `FLAGS.gradient_type` is neither "vanilla" nor "reparam".
  """
  # Fixed seeds for both TensorFlow and NumPy (the latter drives the noise
  # added to the vertices below).
  tf.random.set_random_seed(20200823)
  np.random.seed(20200823)

  input_fn = datasets.get_dataset("test", FLAGS)
  batch = input_fn(None).make_one_shot_iterator().get_next()

  # Extracting a motion sequence in a data dict
  data = {}
  with tf.Session() as sess:
    while True:
      try:
        batch_val = sess.run(batch)
        key = batch_val["name"][0]
        data[key] = batch_val
        # Perturb the vertices with zero-mean Gaussian noise (std 5e-3).
        data[key]["vert"] = (
            data[key]["vert"] +
            np.random.normal(0, 5e-3, data[key]["vert"].shape))
      except tf.errors.OutOfRangeError:
        # The one-shot iterator is exhausted: every example has been read.
        break
    sorted_keys = sorted(data.keys())

  # The first frame seeds the accumulated transform and the joints below, so
  # fail early with a clear message instead of an IndexError deep in the graph
  # construction (or a division by zero when averaging the IoU).
  if not sorted_keys:
    raise ValueError("The test dataset is empty; there is nothing to track.")

  # Parse relevant parameters for theta optimization
  trans_range = FLAGS.trans_range
  n_dims = FLAGS.n_dims
  n_translate = 3 if n_dims == 3 else 2
  n_rotate = 6 if n_dims == 3 else 2

  # Set up parameters and place holders.
  # The graph used for reading the dataset is discarded; everything below is
  # built in a fresh default graph.
  tf.reset_default_graph()
  accum_mat_holder = tf.placeholder(tf.float32,
                                    [FLAGS.n_parts, n_dims + 1, n_dims + 1])
  pt_holder = tf.placeholder(tf.float32, [1, 1, None, FLAGS.n_dims])
  weight_holder = tf.placeholder(tf.float32, [1, 1, None, FLAGS.n_parts])
  # Scalar placeholders used only to feed already-computed values into the
  # summaries.
  loss_holder = tf.placeholder(tf.float32, [])
  glue_loss_holder = tf.placeholder(tf.float32, [])
  iou_holder = tf.placeholder(tf.float32, [])
  id_transform = model_utils.get_identity_transform(n_translate, n_rotate,
                                                    FLAGS.n_parts)
  theta = tf.Variable(id_transform, trainable=True, name="pose_var")

  # Compute transformation matrix and joints according to theta
  temp_mat = model_utils.get_transform_matrix(theta, trans_range, n_translate,
                                              n_rotate, n_dims)
  # FLAGS.left_trans selects on which side the accumulated transform is
  # multiplied; the same order is mirrored in NumPy at the end of each frame.
  if FLAGS.left_trans:
    trans_mat = tf.matmul(
        tf.reshape(accum_mat_holder,
                   [1, 1, FLAGS.n_parts, n_dims + 1, n_dims + 1]),
        tf.reshape(temp_mat, [1, 1, FLAGS.n_parts, n_dims + 1, n_dims + 1]))
  else:
    trans_mat = tf.matmul(
        tf.reshape(temp_mat, [1, 1, FLAGS.n_parts, n_dims + 1, n_dims + 1]),
        tf.reshape(accum_mat_holder,
                   [1, 1, FLAGS.n_parts, n_dims + 1, n_dims + 1]))
  # Split trans_mat into its upper-left block `r` and last-column block `t`,
  # then assemble [r^T | -r^T t] on top of trans_mat's last row. This equals
  # the inverse of trans_mat when `r` is orthonormal (assumption -- confirm
  # against model_utils.get_transform_matrix).
  r = trans_mat[..., :n_dims, :n_dims]
  t = trans_mat[..., :n_dims, -1:]
  r_t = tf.transpose(r, [0, 1, 2, 4, 3])
  t_0 = -tf.matmul(r_t, t)
  joint_trans = tf.concat(
      [tf.concat([r_t, t_0], axis=-1), trans_mat[..., -1:, :]], axis=-2)
  joint_trans = tf.reshape(joint_trans, [FLAGS.n_parts, n_dims + 1, n_dims + 1])
  inv_first_frame_trans = data[sorted_keys[0]]["transform"].reshape(
      [FLAGS.n_parts, n_dims + 1, n_dims + 1])
  joint_trans = tf.matmul(joint_trans, inv_first_frame_trans)
  first_frame_joint = data[sorted_keys[0]]["joint"].reshape(
      [FLAGS.n_parts, n_dims, 1])
  # Append a row of ones (homogeneous coordinates) so the joints can be
  # multiplied by the (n_dims + 1) x (n_dims + 1) matrices.
  first_frame_joint = tf.concat(
      [first_frame_joint,
       tf.ones_like(first_frame_joint[..., :1, :])], axis=-2)
  joint = tf.matmul(joint_trans, first_frame_joint)[..., :-1, 0]

  if FLAGS.glue_w > 0.:
    with tf.io.gfile.GFile(FLAGS.joint_data, "rb") as cin:
      connectivity = np.load(cin)
    end_points = data[sorted_keys[0]]["joint"].reshape([FLAGS.n_parts, n_dims])
    first_frame_trans = data[sorted_keys[0]]["transform"].reshape(
        [FLAGS.n_parts, n_dims + 1, n_dims + 1])
    glue_loss = utils.compute_glue_loss(
        connectivity, end_points,
        tf.reshape(trans_mat, [FLAGS.n_parts, n_dims + 1, n_dims + 1]),
        first_frame_trans, joint, FLAGS)
  else:
    # Glue term disabled: keep a constant so it can still be fetched/logged.
    glue_loss = tf.constant(0, dtype=tf.float32)

  # Set up computation graph
  model_fn = models.get_model(FLAGS)
  batch_holder = {
      "transform": trans_mat,
      "joint": joint,
      "point": pt_holder,
      "weight": weight_holder,
  }
  if FLAGS.gradient_type == "vanilla":
    interface = utils.vanilla_theta_gradient(model_fn, batch_holder, FLAGS)
  elif FLAGS.gradient_type == "reparam":
    interface = utils.reparam_theta_gradient(model_fn, batch_holder, FLAGS)
  else:
    # Without this branch `interface` would be unbound and the unpacking below
    # would fail with an uninformative UnboundLocalError.
    raise ValueError(
        "Unknown gradient_type: {!r}; expected 'vanilla' or 'reparam'.".format(
            FLAGS.gradient_type))

  # Parse content of the interface
  latent_holder, latent, occ, rec_loss = interface
  if FLAGS.glue_w > 0:
    loss = rec_loss + glue_loss * FLAGS.glue_w
  else:
    loss = rec_loss
  global_step = tf.train.get_or_create_global_step()
  optimizer = tf.train.AdamOptimizer(FLAGS.theta_lr)
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    # Only theta is optimized.
    train_op = optimizer.minimize(
        loss,
        var_list=[theta],
        global_step=global_step,
        name="optimize_theta",
    )
    # Re-initializes theta, the global step and the optimizer's own variables;
    # handed to utils.optimize_theta together with train_op.
    reset_op = tf.variables_initializer(
        [theta, global_step] + optimizer.variables(), name="reset_button")

  tf.summary.scalar("Loss", loss_holder)
  tf.summary.scalar("IoU", iou_holder)
  tf.summary.scalar("Glue", glue_loss_holder)
  summary_op = tf.summary.merge_all()

  # Load checkpoint and run optimization
  # Variables under the "shape/" scope are initialized from the checkpoint in
  # FLAGS.train_dir when init_op runs.
  assignment_map = {
      "shape/": "shape/",
  }
  tf.train.init_from_checkpoint(FLAGS.train_dir, assignment_map)
  init_op = tf.global_variables_initializer()
  with tf.summary.FileWriter(FLAGS.train_dir) as summary_writer:
    with tf.Session() as sess:
      sess.run(init_op)
      # The accumulated transform starts from the first frame's transform.
      accum_mat = data[sorted_keys[0]]["transform"].reshape(
          [FLAGS.n_parts, n_dims + 1, n_dims + 1])
      accum_iou = 0.
      example_cnt = 0
      for frame_id, k in enumerate(sorted_keys):
        data_example = data[k]
        feed_dict = {
            pt_holder: data_example["vert"],
            weight_holder: data_example["weight"],
            accum_mat_holder: accum_mat,
        }
        loss_val, loss_glue_val = utils.optimize_theta(feed_dict, loss,
                                                       reset_op, train_op,
                                                       rec_loss, glue_loss,
                                                       sess, k, FLAGS)
        iou = utils.compute_iou(sess, feed_dict, latent_holder, pt_holder,
                                latent, occ[:, -1:], data_example["points"],
                                data_example["labels"], FLAGS)
        accum_iou += iou
        example_cnt += 1
        utils.save_mesh(
            sess,
            feed_dict,
            latent_holder,
            pt_holder,
            latent,
            occ,
            data_example,
            FLAGS,
            pth="tracked_{}".format(FLAGS.gradient_type))
        utils.save_pointcloud(
            data_example,
            FLAGS,
            pth="pointcloud_{}".format(FLAGS.gradient_type))

        summary = sess.run(summary_op, {
            loss_holder: loss_val,
            iou_holder: iou,
            glue_loss_holder: loss_glue_val
        })
        summary_writer.add_summary(summary, frame_id)
        summary_writer.flush()

        # Fold this frame's optimized transform into the accumulated one,
        # using the same multiplication order as in the graph above.
        temp_mat_val = sess.run(temp_mat)
        if FLAGS.left_trans:
          accum_mat = np.matmul(
              accum_mat,
              temp_mat_val.reshape([FLAGS.n_parts, n_dims + 1, n_dims + 1]))
        else:
          accum_mat = np.matmul(
              temp_mat_val.reshape([FLAGS.n_parts, n_dims + 1, n_dims + 1]),
              accum_mat)

      # Mean IoU over all tracked frames; example_cnt >= 1 is guaranteed by
      # the empty-dataset check above.
      with tf.io.gfile.GFile(
          path.join(FLAGS.train_dir, "tracked_{}".format(FLAGS.gradient_type),
                    "iou.txt"), "w") as iout:
        iout.write("{}\n".format(accum_iou / example_cnt))