in tensorflow_graphics/projects/cvxnet/eval.py [0:0]
def main(unused_argv):
  """Evaluates checkpoints from `FLAGS.train_dir` on the 'test' split.

  The graph is built once. The function then repeatedly opens a monitored
  session on `FLAGS.train_dir`; whenever the restored global step is newer
  than the last step evaluated, batches are consumed until the session asks
  to stop and per-example statistics are accumulated and averaged.

  Reporting depends on the flags:
    * neither `FLAGS.extract_mesh` nor `FLAGS.surface_metrics`: the averaged
      IoU is written as a 'test_iou' scalar summary under
      `<FLAGS.train_dir>/eval`.
    * `FLAGS.extract_mesh`: each extracted mesh is passed to
      `utils.save_mesh`.
    * `FLAGS.surface_metrics`: chamfer / f-score are computed per example and
      the statistics are passed to `utils.write_stats`.

  The loop terminates after one evaluation when `FLAGS.eval_once` is set, or
  as soon as the evaluated step is at least `FLAGS.max_steps`.

  Args:
    unused_argv: Unused.
  """
  tf.set_random_seed(2191997)
  np.random.seed(6281996)

  logging.info('=> Starting ...')
  eval_dir = path.join(FLAGS.train_dir, 'eval')

  # Input pipeline for the 'test' split.
  logging.info('=> Preparing datasets ...')
  dataset = datasets.get_dataset(FLAGS.dataset, 'test', FLAGS)
  batch = tf.data.make_one_shot_iterator(dataset).get_next()

  # Model selection.
  logging.info('=> Creating {} model'.format(FLAGS.model))
  model = models.get_model(FLAGS.model, FLAGS)

  # Graph construction.
  global_step = tf.train.get_or_create_global_step()
  test_loss, test_iou = model.compute_loss(batch, training=False)

  # Both mesh extraction and surface metrics need a mesh per example.
  needs_mesh = FLAGS.extract_mesh or FLAGS.surface_metrics
  if needs_mesh:
    # Placeholder-fed encode/decode ops; these are handed to
    # utils.extract_mesh together with the live session further down.
    n_channels = 3 if FLAGS.image_input else FLAGS.depth_d
    input_holder = tf.placeholder(tf.float32, [None, 224, 224, n_channels])
    params = model.encode(input_holder, training=False)
    params_holder = tf.placeholder(tf.float32, [None, model.n_params])
    points_holder = tf.placeholder(tf.float32, [None, None, FLAGS.dims])
    indicators, _ = model.decode(
        params_holder, points_holder, training=False)

  # NOTE(review): the summary ops are built unless *both* mesh flags are set,
  # but they are only used below when *neither* is set. The condition looks
  # like it was meant to be `not needs_mesh`. Left as-is because constructing
  # the FileWriter may create `eval_dir` as a side effect that later writes
  # could rely on -- confirm against utils before tightening.
  if not (FLAGS.extract_mesh and FLAGS.surface_metrics):
    summary_writer = tf.summary.FileWriter(eval_dir)
    iou_holder = tf.placeholder(tf.float32)
    iou_summary = tf.summary.scalar('test_iou', iou_holder)

  logging.info('=> Evaluating ...')
  last_step = -1
  while True:
    shapenet_stats = utils.init_stats()
    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[],
        save_checkpoint_steps=None,
        save_checkpoint_secs=None,
        save_summaries_steps=None,
        save_summaries_secs=None,
        log_step_count_steps=None,
        max_wait_secs=3600) as session:
      step_val = session.run(global_step)
      if step_val <= last_step:
        # Nothing newer than what was already evaluated: close this session
        # and poll again.
        # NOTE(review): there is no delay before the next poll.
        continue
      last_step = step_val

      while not session.should_stop():
        batch_val, _, iou_val = session.run([batch, test_loss, test_iou])
        names = batch_val['name']

        if needs_mesh:
          input_key = 'image' if FLAGS.image_input else 'depth'
          mesh = utils.extract_mesh(
              batch_val[input_key],
              params,
              indicators,
              input_holder,
              params_holder,
              points_holder,
              session,
              FLAGS,
          )
          if FLAGS.trans_dir is not None:
            utils.transform_mesh(mesh, names, FLAGS.trans_dir)

        if FLAGS.extract_mesh:
          utils.save_mesh(mesh, names, eval_dir)

        # Surface metrics stay at zero unless explicitly requested.
        chamfer = fscore = 0.
        if FLAGS.surface_metrics:
          chamfer, fscore = utils.compute_surface_metrics(
              mesh, names, FLAGS.mesh_dir)

        # Only the first IoU entry of the batch is recorded.
        # NOTE(review): presumably evaluation runs with batch size 1 --
        # confirm against the dataset configuration.
        stats = utils.Stats(iou=iou_val[0], chamfer=chamfer, fscore=fscore)
        utils.update_stats(stats, names, shapenet_stats)

    utils.average_stats(shapenet_stats)

    if not needs_mesh:
      # Plain evaluation: publish the averaged IoU for this step.
      with tf.Session() as sess:
        feed = {iou_holder: shapenet_stats['all']['iou']}
        iou_summary_val = sess.run(iou_summary, feed_dict=feed)
        summary_writer.add_summary(iou_summary_val, step_val)
        summary_writer.flush()

    if FLAGS.surface_metrics:
      utils.write_stats(
          shapenet_stats,
          eval_dir,
          step_val,
      )

    if FLAGS.eval_once or step_val >= FLAGS.max_steps:
      break