def eval_florence()

in c3dm/dataset/eval_zoo.py [0:0]


def eval_florence(cached_preds, eval_vars=None, TGT_NIMS=1427, visualize=False):
	"""Evaluate cached shape predictions against Florence face meshes.

	For every cached image that has a ground-truth mesh path and a mask with
	more than one active element, the predicted ``shape_image_coord_best_scale``
	map is resized to the mask resolution and compared with the ground-truth
	mesh using ``eval_pcl_icp``. Each prediction is evaluated twice: once with
	its channel 2 negated (metrics suffixed ``_flip``) and once as-is.
	Per-image errors are then aggregated into medians (``<metric>_med``) and
	means (``<metric>``) over all evaluated images.

	Args:
		cached_preds: dict of per-image sequences; the keys read here are
			``'mesh_path'``, ``'kp_loc'``, ``'masks'`` and
			``'shape_image_coord_best_scale'``.
		eval_vars: optional iterable of metric names that must be present in
			the aggregated results.
		TGT_NIMS: not used by this function; kept for interface compatibility.
		visualize: not used by this function; kept for interface compatibility.

	Returns:
		Tuple ``(results, None)`` where ``results`` maps metric names to
		floats. ``results`` is empty if no image could be evaluated.

	Raises:
		AssertionError: if a name listed in ``eval_vars`` is missing from
			``results``.
	"""
	from tools.eval_functions import eval_pcl_icp

	root = IMAGE_ROOTS['celeba_ff'][1]

	# output metric name -> key in the dict returned by eval_pcl_icp
	metric_keys = {
		'EVAL_pcl_scl_recut_orthographic': 'dist_pcl_scl_recut',
		'EVAL_pcl_scl_orthographic':       'dist_pcl_scl',
		'EVAL_pcl_orthographic':           'dist_pcl',
	}

	nim = len(cached_preds['mesh_path'])
	errs = []

	for imi in tqdm(range(nim)):

		# ground-truth Florence mesh; images without a mesh path are skipped
		mesh_path = cached_preds['mesh_path'][imi]
		if not mesh_path:
			continue
		vertices, faces = load_ff_obj(os.path.join(root, mesh_path))
		mesh_gt = trimesh.Trimesh(
			vertices=vertices.tolist(),
			faces=faces.tolist()
		)

		# our prediction
		kp_loc = cached_preds['kp_loc'][imi]
		mask   = cached_preds['masks'][imi]

		if mask.sum() <= 1:
			print('Empty mask!!!')
			continue

		# mask presumably is (1, H, W) -- TODO confirm; the spatial size is
		# taken from its trailing dimensions
		image_size = list(mask.shape[1:])
		pcl_pred = cached_preds['shape_image_coord_best_scale'][imi]
		pcl_pred = Fu.interpolate(pcl_pred[None], size=image_size)[0]

		err_now = {}
		# the flipped variant is evaluated first so that its keys come first
		# in err_now (and thus in the aggregated results)
		for flip in (True, False):
			pcl_pred_now = pcl_pred.clone()
			if flip:
				pcl_pred_now[2, :] = -pcl_pred_now[2, :]
			icp_err = eval_pcl_icp(pcl_pred_now, mesh_gt, mask, kp_loc)
			suffix = '_flip' if flip else ''
			for out_name, icp_name in metric_keys.items():
				err_now[out_name + suffix] = icp_err[icp_name]

		errs.append(err_now)

		print('<EVAL_STATE>')
		print(f'IMAGE={imi}')
		print(err_now)
		# escaped backslash: prints the same text the original '<\EVAL_STATE>'
		# produced, without relying on an invalid escape sequence
		print('<\\EVAL_STATE>')

	results = {}
	if errs:
		# one tensor per metric, built once and reused for median and mean
		per_metric = {
			k: torch.FloatTensor([float(e[k]) for e in errs])
			for k in errs[0]
		}
		# medians first, then means (same key order as before).
		# NOTE: torch.median returns the lower of the two middle values for
		# an even number of elements.
		for k, vals in per_metric.items():
			results[k + '_med'] = float(vals.median())
		for k, vals in per_metric.items():
			results[k] = float(vals.mean())
	else:
		print('Warning: no Florence images could be evaluated!')

	print('Florence Face evaluation results:')
	for k, v in results.items():
		print('%20s: %1.5f' % (k, v))

	if eval_vars is not None:
		for eval_var in eval_vars:
			# explicit raise (not assert) so the check survives `python -O`
			if eval_var not in results:
				raise AssertionError(
					'evaluation variable missing! (%s)' % str(eval_var))
		print('eval vars check ok!')

	return results, None