def main()

in tensorflow_graphics/projects/points_to_3Dobjects/train_multi_objects/train.py [0:0]


def _wait_for_metric_files(metrics_dir, expected_count, verbose=False):
  """Blocks until `metrics_dir` contains at least `expected_count` entries.

  Polls every 5 seconds with no timeout, so this waits indefinitely if fewer
  than `expected_count` entries ever appear.

  Args:
    metrics_dir: Directory to poll.
    expected_count: Minimum number of directory entries to wait for.
    verbose: If True, print progress on every poll that is still short.
  """
  while True:
    # List once per poll (the original listed twice when printing progress).
    num_files = len(tf.io.gfile.listdir(metrics_dir))
    if num_files >= expected_count:
      return
    if verbose:
      print('waiting...', num_files, 'out of', expected_count)
    time.sleep(5)


def _eval_iou(expected_num_files=100):
  """Aggregates per-part IoU pickles into `<logdir>/<metrics_dir>/iou.txt`.

  Each file in `<logdir>/<FLAGS.metrics_dir>/iou` must hold a pickled dict
  mapping an integer class id to a list of values exposing `.numpy()`
  (presumably eager tensors -- confirm against the writer in val()). The
  summary lists mean (std) IoU per class, then mIoU (the mean of per-class
  means) and global IoU (the mean over all collected values).

  Args:
    expected_num_files: Number of directory entries to wait for before
      aggregating.
  """
  metrics_dir = os.path.join(FLAGS.logdir, FLAGS.metrics_dir, 'iou')
  if not tf.io.gfile.exists(metrics_dir):
    tf.io.gfile.makedirs(metrics_dir)
  # NOTE(review): the collision evaluation waits for FLAGS.n_tfrecords files
  # while this waits for a fixed count; confirm the two should agree.
  _wait_for_metric_files(metrics_dir, expected_num_files, verbose=True)

  all_iou_per_class = {}
  for i, iou_file in enumerate(tf.io.gfile.listdir(metrics_dir)):
    logging.info(i)
    iou_file_path = os.path.join(metrics_dir, iou_file)
    # pickle.load executes arbitrary code from the file; only point this at
    # trusted, locally produced metric files.
    with gfile.Open(iou_file_path, 'rb') as f:
      print(iou_file_path)
      iou_per_class = pickle.load(f)
    for class_id, ious in iou_per_class.items():
      # extend() instead of list + list avoids quadratic re-copying.
      all_iou_per_class.setdefault(class_id, []).extend(
          iou.numpy() for iou in ious)

  class_id_to_name = ['chair', 'sofa', 'table', 'bottle', 'bowl', 'mug']
  mean_iou_per_class = {}
  all_iou = []
  with gfile.Open(metrics_dir + '.txt', 'wb') as f:
    for class_id, ious in all_iou_per_class.items():
      class_mean = np.mean(ious)
      mean_iou_per_class[class_id] = class_mean
      # ScanNet runs set num_classes = 8, more than this table names; fall
      # back to the numeric id instead of raising IndexError.
      if 0 <= class_id < len(class_id_to_name):
        class_name = class_id_to_name[class_id]
      else:
        class_name = str(class_id)
      f.write(class_name + ':\t' +
              str(class_mean) + ' (' + str(np.std(ious)) + ')\n')
      all_iou.extend(ious)
    per_class_mean = np.mean(list(mean_iou_per_class.values()))
    global_mean = np.mean(all_iou)
    f.write('\nmIoU:\t' + str(per_class_mean))
    f.write('\nglobal IoU:\t' + str(global_mean))


def _eval_collision():
  """Aggregates collision pickles into `<logdir>/<metrics_dir>/collision.txt`.

  Waits for FLAGS.n_tfrecords entries in `<logdir>/<FLAGS.metrics_dir>/
  collision`. Each entry must hold a pickled dict with keys 'collisions'
  (summed), 'intersections' and 'ious' (both concatenated, then averaged).
  """
  metrics_dir = os.path.join(FLAGS.logdir, FLAGS.metrics_dir, 'collision')
  if not tf.io.gfile.exists(metrics_dir):
    tf.io.gfile.makedirs(metrics_dir)
  _wait_for_metric_files(metrics_dir, FLAGS.n_tfrecords)

  total_collisions = 0
  total_intersections = []
  total_ious = []
  for i, collision_file in enumerate(tf.io.gfile.listdir(metrics_dir)):
    logging.info(i)
    file_path = os.path.join(metrics_dir, collision_file)
    # pickle.load executes arbitrary code from the file; trusted input only.
    with gfile.Open(file_path, 'rb') as f:
      collision_data = pickle.load(f)
    total_collisions += np.sum(collision_data['collisions'])
    total_intersections.extend(collision_data['intersections'])
    total_ious.extend(collision_data['ious'])

  with gfile.Open(metrics_dir + '.txt', 'wb') as f:
    f.write('\ncollisions:\t' + str(total_collisions))
    f.write('\nintersect.:\t' + str(np.mean(total_intersections)))
    f.write('\niou:\t' + str(np.mean(total_ious)))


def main(_):
  """Script entry point: train, validate one part, or aggregate val metrics.

  Adjusts FLAGS first. `FLAGS.francis` overrides `FLAGS.debug` for the
  split, and a 'scannet' tfrecords_dir switches to 8 classes with 640x640
  images. It then ensures FLAGS.logdir exists and dispatches:
    * FLAGS.train: run train().
    * FLAGS.val with part_id == -1: aggregate the IoU and collision metric
      files (presumably written by earlier per-part val runs) into summary
      text files.
    * FLAGS.val otherwise: run val() on the given part.

  Args:
    _: Unused positional argument (presumably argv from absl app.run).
  """
  if FLAGS.debug:
    FLAGS.split = 'train'
  if FLAGS.francis:
    FLAGS.split = 'francis'
  if 'scannet' in FLAGS.tfrecords_dir:
    FLAGS.num_classes = 8
    FLAGS.image_width = 640
    FLAGS.image_height = 640

  if not tf.io.gfile.exists(FLAGS.logdir):
    tf.io.gfile.makedirs(FLAGS.logdir)

  if FLAGS.train:
    train()
  elif FLAGS.val:
    if FLAGS.part_id == -1:
      _eval_iou()
      _eval_collision()
    else:
      val(record_losses=FLAGS.record_val_losses,
          split=FLAGS.split,
          part_id=FLAGS.part_id)