def visualize()

in c3dm/model.py [0:0]


	def visualize(self, visdom_env_imgs, trainmode, preds, stats, clear_env=False):
		"""
		Send diagnostic visualisations of a batch of predictions to visdom.

		Only the first sample of the batch (``idx_image = 0``) is shown in the
		per-sample plots. All windows are drawn into ``visdom_env_imgs`` and
		titled ``e<epoch>_it<iteration>_im<idx_image>`` (plus a suffix for
		some windows).

		Args:
			visdom_env_imgs: visdom environment to draw into.
			trainmode: key into ``stats.it`` selecting the iteration counter
				used in window titles (only read when ``stats`` is given).
			preds: dict of model inputs / outputs for the batch.
			stats: optional object providing ``it``, ``epoch``,
				``visdom_server`` and ``visdom_port``. When None, a default
				visdom connection is used and epoch / iteration are 0.
			clear_env: not read by this method; kept for interface
				compatibility.

		Side effects:
			Stores ``preds['shape_canonical_same_alphas']``. Prints a message
			and returns early when no visdom server is reachable.
		"""
		if stats is not None:
			it = stats.it[trainmode]
			epoch = stats.epoch
			viz = vis_utils.get_visdom_connection(
				server=stats.visdom_server,
				port=stats.visdom_port,
			)
		else:
			it = 0
			epoch = 0
			viz = vis_utils.get_visdom_connection()

		if not viz.check_connection():
			print("no visdom server! -> skipping batch vis")
			return

		idx_image = 0

		title = "e%d_it%d_im%d" % (epoch, it, idx_image)

		# prefer the augmented inputs when they are present in preds
		imvar = 'images_aug' if 'images_aug' in preds else 'images'
		mvar = 'masks_aug' if 'masks_aug' in preds else 'masks'

		# --- dense depth (and predicted mask, if any) stacked under the images
		ds = preds['depth_dense'].cpu().detach().repeat(1, 3, 1, 1)
		ims = preds[imvar].cpu().detach()
		ims = Fu.interpolate(ims, size=ds.shape[2:])
		if mvar in preds:  # mask depths, ims by masks
			masks = Fu.interpolate(preds[mvar].cpu().detach(),
								   size=ds.shape[2:], mode='nearest')
			ims *= masks
			ds *= masks
		ds = vis_utils.denorm_image_trivial(ds)
		if 'pred_mask' in preds:
			pred_mask = torch.sigmoid(preds['pred_mask'][:, None].detach()).cpu().expand_as(ims)
			ims_ds = torch.cat((ims, ds, pred_mask), dim=2)
		else:
			ims_ds = torch.cat((ims, ds), dim=2)
		viz.images(ims_ds, env=visdom_env_imgs, opts={'title': title}, win='depth')

		# --- input / appearance / geometry images, each annotated with the
		# relative rotation angle (degrees) of every sample w.r.t. sample 0,
		# for the predicted and the nrsfm rotations.
		# The angles do not depend on the image key, so compute them once.
		R, R_gt = preds['phi']['R'], preds['nrsfm']['phi']['R']
		angle_to_0 = np.rad2deg(
			so3.so3_relative_angle(R[0].expand_as(R), R).data.cpu().numpy()
		)
		angle_to_0_gt = np.rad2deg(
			so3.so3_relative_angle(R_gt[0].expand_as(R_gt), R_gt).data.cpu().numpy()
		)
		# "%d" cannot format NaN, so both angle arrays must be NaN-free before
		# they are written into the images.
		angles_valid = not (
			np.isnan(angle_to_0).any() or np.isnan(angle_to_0_gt).any()
		)
		imss = []
		for k in (imvar, 'images_app', 'images_geom'):
			if k not in preds:
				continue
			ims = preds[k].cpu().detach()
			ims = Fu.interpolate(ims, scale_factor=0.25)
			ims = vis_utils.denorm_image_trivial(ims)
			if angles_valid:
				ims = np.stack([
					vis_utils.write_into_image(
						(im*255.).astype(np.uint8), "%d° / %d°" % (d, d_gt), color=(255,0,255)
					) for im, d, d_gt in zip(ims.data.numpy(), angle_to_0, angle_to_0_gt)
				])
			else:
				ims = (ims.data.numpy()*255.).astype(np.uint8)
			imss.append(ims)
		if imss:
			viz.images(
				np.concatenate(imss, axis=2).astype(np.float32) / 255.,
				env=visdom_env_imgs,
				opts={'title': title},
				win='imaug',
			)

		# --- sparse keypoints: input locations, dense-model reprojections and
		# nrsfm reprojections for sample idx_image
		kp_var = 'kp_loc_aug' if 'kp_loc_aug' in preds else 'kp_loc'
		p1 = preds[kp_var][idx_image]
		p2 = preds['kp_reprojected_image'][idx_image, 0:2]
		p3 = preds['nrsfm']['kp_reprojected_image'][idx_image]
		p = np.stack([p_.detach().cpu().numpy() for p_ in (p1, p2, p3)])
		v = preds['kp_vis'][idx_image].detach().cpu().numpy()
		vis_utils.show_projections(viz, visdom_env_imgs, p, v=v,
					title=title, cmap__='rainbow',
					markersize=50, sticks=None,
					stickwidth=1, plot_point_order=False,
					image=preds[imvar][idx_image].detach().cpu().numpy(),
					win='projections')

		# --- dense reprojections
		# 'image_repro_gt' is used without a batch index here (as in the
		# original code) -- presumably it is shared by the whole batch.
		p1 = preds['image_repro_gt'].detach().cpu()
		p2 = preds['shape_reprojected_image'][idx_image].detach().cpu()
		# From here on `mvar` refers to the embedding-resolution masks
		# (augmentation applied if any); the 3D point-cloud section below
		# also uses this override.
		mvar = 'embed_masks'
		if mvar in preds:
			masks = preds[mvar].detach().cpu()
			p1 = p1 * masks[idx_image]
			p2 = p2 * masks[idx_image]

		# TEMP: image of sample idx_image masked by the (nearest-upsampled)
		# embedding mask. Note this reads preds['embed_masks'] unconditionally.
		img = (preds[imvar][idx_image].cpu() * Fu.interpolate(
			preds[mvar].cpu()[idx_image:idx_image+1], size=preds[imvar][0, 0].size(), mode='nearest'
		)[0]).data.cpu().numpy()
		p = np.stack([p_.view(2, -1).numpy() for p_ in (p1, p2)])
		vis_utils.show_projections(viz, visdom_env_imgs, p, v=None,
					title=title, cmap__='rainbow',
					markersize=1, sticks=None,
					stickwidth=1, plot_point_order=False,
					image=img,
					win='projections_dense')
		vis_utils.show_flow(viz, visdom_env_imgs, p,
			image=preds[imvar][idx_image].detach().cpu().numpy(),
			title='flow ' + title,
			linewidth=1,
			win='projections_flow',
		)

		# --- sphere-sample projections (optionally with their gt counterparts)
		if 'sph_sample_projs' in preds:
			p = preds['sph_sample_projs'][idx_image].detach().cpu().view(2, -1)
			ps = [p]
			if 'sph_sample_gt' in preds:
				p_gt = preds['sph_sample_gt'][idx_image].detach().cpu().view(2, -1)
				# tile the gt samples until there are at least as many as
				# predicted ones, then truncate to the same count
				p_gt = p_gt.repeat(1, math.ceil(p.shape[1] / p_gt.shape[1]))
				ps.append(p_gt[:, :p.shape[1]])
			# affine map (x + 1) / 2 * size: row 0 scaled by the last image
			# dim, row 1 by the second-to-last (presumably [-1, 1] -> pixels)
			p = (torch.stack(ps) + 1.) / 2.
			imsize = preds[imvar][idx_image].shape[1:]
			p[:, 0, :] *= imsize[1]
			p[:, 1, :] *= imsize[0]
			vis_utils.show_projections(viz, visdom_env_imgs,
				p, v=None,
				title=title + '_spl_sil',
				cmap__='rainbow',
				markersize=1, sticks=None,
				stickwidth=1, plot_point_order=False,
				image=preds[imvar][idx_image].detach().cpu().numpy(),
				win='projections_spl_sil'
			)

		# canonical shape of the merged embedding, posed with the global
		# descriptors (phi) and K of the first sample only
		merged_embed = self._merge_masked_tensors(
			preds['embed_full'], preds['embed_masks']
		)[..., None]
		gl_desc_0 = {k: v[:1] for k, v in preds['phi'].items()}
		merged_with_pivot_phis = self._get_shapes_and_projections(
			merged_embed, None, gl_desc_0, preds['K'][:1]
		)
		preds['shape_canonical_same_alphas'] = merged_with_pivot_phis[
			'shape_canonical_dense'
		][0, ..., 0]

		# --- dense 3D point clouds; each entry of pcl_show is a
		# (6, n_points) array: xyz rows followed by rgb rows
		pcl_show = {}
		vis_list = ['dense3d', 'mean_shape', 'embed_db', 'batch_fused', 'sph_embed']
		if self.loss_weights['loss_sph_sample_mask'] > 0:
			vis_list.append('sph_sample_3d')

		for vis in vis_list:
			if vis == 'canonical':
				pcl = preds['shape_canonical_dense']
			elif vis == 'dense3d':
				pcl = preds['shape_image_coord_cal_dense']
			elif vis == 'batch_fused':
				pcl = preds['shape_canonical_same_alphas'].detach().cpu()
				pcl = torch.cat((pcl, pcl), dim=0)
				# solid blue
				pcl[3:5, :] = 0.0
				pcl[5, :] = 1.0
			elif vis == 'mean_shape':
				pcl = preds['embed_mean']
			elif vis == 'mean_c3dpo_shape':
				pcl = preds['nrsfm_mean_shape']
			elif vis == 'shape_canonical':
				pcl = preds['shape_canonical_dense']
			elif vis == 'sph_embed':
				pcl = preds['embed'].detach().clone()
			elif vis == 'sph_sample_3d':
				pcl = preds['sph_sample_3d'][idx_image].detach().cpu().view(3, -1)
				pcl = torch.cat((pcl, pcl.clone()), dim=0)
				# solid red
				pcl[4:, :] = 0.0
				pcl[3, :] = 1.0
				# filtering outliers: centre, then clamp to 2.5 * max-axis std
				pcl[:3] -= pcl[:3].mean(dim=1, keepdim=True)  # will be centered anyway
				std = pcl[:3].std(dim=1).max()
				pcl[:3] = pcl[:3].clamp(-2.5*std, 2.5*std)
			elif vis == 'embed_db':
				pcl = self.embed_db.get_db(uniform_sphere=False).cpu().detach().view(3, -1)
				pcl = torch.cat((pcl, pcl.clone()), dim=0)
				# solid green: clear ALL colour rows (previously row 5 kept
				# the z coordinate), then set the green channel
				pcl[3:, :] = 0.0
				pcl[4, :] = 1.0
			else:
				raise ValueError(vis)

			if vis not in ('mean_c3dpo_shape', 'batch_fused', 'sph_sample_3d', 'embed_db'):
				# colour the points by the input image, resampled to the
				# point cloud's spatial size; drop points outside the mask
				pcl_rgb = preds[imvar].detach().cpu()
				pcl_rgb = Fu.interpolate(pcl_rgb, size=pcl.shape[2:], mode='bilinear')
				if mvar in preds:
					masks = preds[mvar].detach().cpu()
					masks = Fu.interpolate(masks,
						size=pcl.shape[2:], mode='nearest')
				else:
					masks = None

				pcl = pcl.detach().cpu()[idx_image].view(3, -1)
				pcl_rgb = pcl_rgb[idx_image].view(3, -1)
				pcl = torch.cat((pcl, pcl_rgb), dim=0)
				if masks is not None:
					masks = masks[idx_image].view(-1)
					pcl = pcl[:, masks > 0.]

			if pcl.numel() == 0:
				continue
			pcl_show[vis] = pcl.numpy()

		# All clouds share one plot; its title / window are named after the
		# last entry of vis_list (this was previously taken implicitly from
		# the leaked loop variable).
		last_vis = vis_list[-1]
		vis_utils.visdom_plotly_pointclouds(viz, pcl_show, visdom_env_imgs,
							title=title + '_' + last_vis,
							markersize=1,
							sticks=None, win=last_vis,
							height=700, width=700,
							normalise=True,
		)

		# --- sparse 3D keypoints: nrsfm vs dense model (vs gt if present)
		var3d = {
			'orthographic': 'shape_image_coord',
			'perspective': 'shape_image_coord_cal',
		}[self.projection_type]
		sparse_pcl = {
			'nrsfm': preds['nrsfm'][var3d][idx_image].detach().cpu().numpy().copy(),
			'dense': preds['shape_image_coord_cal'][idx_image].detach().cpu().numpy().copy(),
		}
		if 'kp_loc_3d' in preds:
			sparse_pcl['gt'] = preds['kp_loc_3d'][idx_image].detach().cpu().numpy().copy()

		if 'class_mask' in preds:
			class_mask = preds['class_mask'][idx_image].detach().cpu().numpy()
			sparse_pcl = {k: v*class_mask[None] for k, v in sparse_pcl.items()}

		vis_utils.visdom_plotly_pointclouds(viz, sparse_pcl, visdom_env_imgs,
								title=title + '_sparse3d',
								markersize=5,
								sticks=None, win='nrsfm_3d',
								height=500,
								width=500)

		if 'photo_out' in preds and preds['photo_out'] is not None:
			# show the source images and their renders
			photo_out = preds['photo_out']
			ims_src     = photo_out['images'].detach().cpu()
			ims_repro   = photo_out['images_reproject'].detach().cpu()
			ims_reenact = photo_out['images_reenact'].detach().cpu()
			ims_gt      = photo_out['images_gt'].detach().cpu()

			# cat all the images
			ims = torch.cat((ims_src, ims_reenact, ims_repro, ims_gt), dim=2)
			ims = torch.clamp(ims, 0., 1.)
			viz.images(ims, env=visdom_env_imgs, opts={'title': title}, win='imrepro')

			# the gt image does not depend on the render key -> fetch it once
			im_gt = photo_out['images_gt'][0].detach().cpu()
			im_renders = photo_out['image_ref_render']
			for render_key in im_renders:
				im_render = im_renders[render_key].detach().cpu()
				im = torch.cat((im_gt, im_render), dim=2)
				im = torch.clamp(im, 0., 1.)
				viz.image(im, env=visdom_env_imgs,
					opts={'title': title+'_min_render_%s' % render_key},
					win='imrender_%s' % render_key)

		if 'app_out' in preds and preds['app_out'] is not None:
			# show the source images and their predictions
			ims_src  = preds['app_out']['images'].detach().cpu()
			ims_pred = preds['app_out']['images_pred_clamp'].detach().cpu()
			ims = torch.cat((ims_src, ims_pred), dim=2)
			viz.images(ims, env=visdom_env_imgs, opts={'title': title}, win='impred')