in c3dm/model.py [0:0]
def visualize( self, visdom_env_imgs, trainmode, \
               preds, stats, clear_env=False ):
    """Push a visual summary of the current batch predictions to visdom.

    Draws (into the `visdom_env_imgs` environment): masked depth maps,
    down-scaled augmented images annotated with camera angles, sparse and
    dense keypoint re-projections, silhouette-sample projections, several
    3D point-cloud views, and — when present in `preds` — photometric
    ('photo_out') and appearance ('app_out') outputs.

    Args:
        visdom_env_imgs: Name of the visdom environment to draw into.
        trainmode: Key used to look up the current iteration in `stats.it`.
        preds: Dict of tensors from the forward pass (images, depths, masks,
            keypoints, shapes, ...). NOTE: mutated — the key
            'shape_canonical_same_alphas' is inserted below.
        stats: Object carrying `it`, `epoch`, `visdom_server`, `visdom_port`;
            or None, in which case iteration/epoch default to 0 and the
            default visdom connection is used.
        clear_env: Unused in the visible code (kept for interface
            compatibility).

    Returns:
        None. Exits early (printing a message) when no visdom server is
        reachable.
    """
    # Resolve iteration/epoch counters and the visdom connection.
    if stats is not None:
        it = stats.it[trainmode]
        epoch = stats.epoch
        viz = vis_utils.get_visdom_connection(
            server=stats.visdom_server,
            port=stats.visdom_port,
        )
    else:
        it = 0
        epoch = 0
        viz = vis_utils.get_visdom_connection()
    if not viz.check_connection():
        print("no visdom server! -> skipping batch vis")
        return

    idx_image = 0  # only the first element of the batch is visualised
    title="e%d_it%d_im%d"%(epoch,it,idx_image)

    # Prefer the augmented variants of images/depths/masks when available.
    imvar = 'images_aug' if 'images_aug' in preds else 'images'
    dvar = 'depths_aug' if 'depths_aug' in preds else 'depths'  # NOTE(review): dvar is never read below
    mvar = 'masks_aug' if 'masks_aug' in preds else 'masks'

    # show depth
    ds = preds['depth_dense'].cpu().detach().repeat(1,3,1,1)  # 1-channel depth -> 3 channels for display
    ims = preds[imvar].cpu().detach()
    ims = Fu.interpolate(ims,size=ds.shape[2:])
    if mvar in preds: # mask depths, ims by masks
        masks = Fu.interpolate(preds[mvar].cpu().detach(),
            size=ds.shape[2:], mode='nearest' )
        ims *= masks ; ds *= masks
    ds = vis_utils.denorm_image_trivial(ds)
    if 'pred_mask' in preds:
        # stack image, depth and predicted (sigmoid) mask vertically (dim=2)
        pred_mask = torch.sigmoid(preds['pred_mask'][:, None].detach()).cpu().expand_as(ims)
        ims_ds = torch.cat( (ims, ds, pred_mask), dim=2 )
    else:
        ims_ds = torch.cat( (ims, ds), dim=2 )
    viz.images(ims_ds, env=visdom_env_imgs, opts={'title':title}, win='depth')

    # show aug images if present
    imss = []
    for k in (imvar, 'images_app', 'images_geom'):
        if k in preds:
            ims = preds[k].cpu().detach()
            ims = Fu.interpolate(ims, scale_factor=0.25)  # thumbnail size
            ims = vis_utils.denorm_image_trivial(ims)
            # Rotation of each camera relative to the first one (degrees),
            # for both the predicted and the nrsfm cameras.
            R, R_gt = preds['phi']['R'], preds['nrsfm']['phi']['R']
            angle_to_0 = np.rad2deg(
                so3.so3_relative_angle(R[0].expand_as(R), R).data.cpu().numpy()
            )
            angle_to_0_gt = np.rad2deg(
                so3.so3_relative_angle(R_gt[0].expand_as(R_gt), R_gt).data.cpu().numpy()
            )
            # `~(...).any()` == "no NaN anywhere": burn the angle pair into
            # each thumbnail; otherwise just convert to uint8 unannotated.
            if ~np.isnan(angle_to_0).any():
                ims = np.stack([
                    vis_utils.write_into_image(
                        (im*255.).astype(np.uint8), "%d° / %d°" % (d, d_gt), color=(255,0,255)
                    ) for im, d, d_gt in zip(ims.data.numpy(), angle_to_0, angle_to_0_gt)
                ])
            else:
                ims = (ims.data.numpy()*255.).astype(np.uint8)
            imss.append(ims)
    if len(imss) > 0:
        viz.images(
            #torch.cat(imss, dim=2),
            np.concatenate(imss, axis=2).astype(np.float32)/255.,
            env=visdom_env_imgs,
            opts={'title': title},
            win='imaug',
        )

    # show reprojections
    # Sparse keypoints: input locations vs. this model's and nrsfm's reprojections.
    p1 = preds['kp_loc_aug' if 'kp_loc_aug' in preds else 'kp_loc'][idx_image]
    p2 = preds['kp_reprojected_image'][idx_image,0:2]
    p3 = preds['nrsfm']['kp_reprojected_image'][idx_image]
    p = np.stack([p_.detach().cpu().numpy() for p_ in (p1, p2, p3)])
    v = preds['kp_vis'][idx_image].detach().cpu().numpy()
    vis_utils.show_projections( viz, visdom_env_imgs, p, v=v,
                    title=title, cmap__='rainbow',
                    markersize=50, sticks=None,
                    stickwidth=1, plot_point_order=False,
                    image=preds[imvar][idx_image].detach().cpu().numpy(),
                    win='projections' )

    # dense reprojections
    p1 = preds['image_repro_gt'].detach().cpu()
    p2 = preds['shape_reprojected_image'][idx_image].detach().cpu()
    # override mask with downsampled (augmentation applied if any)
    mvar = 'embed_masks'  # NOTE(review): re-binds mvar for the rest of the function
    if mvar in preds:
        masks = preds[mvar].detach().cpu()
        #masks = Fu.interpolate(masks, size=p2.shape[1:], mode='nearest')
        p1 = p1 * masks[idx_image]
        p2 = p2 * masks[idx_image]
        # TEMP
        # NOTE(review): `img` is only bound when 'embed_masks' is present;
        # the show_projections call below assumes it is — confirm upstream.
        img = (preds[imvar][idx_image].cpu() * Fu.interpolate(
            preds[mvar].cpu()[idx_image:idx_image+1], size=preds[imvar][0, 0].size(), mode='nearest'
        )[0]).data.cpu().numpy()
    p = np.stack([p_.view(2,-1).numpy() for p_ in (p1, p2)])
    vis_utils.show_projections( viz, visdom_env_imgs, p, v=None,
                    title=title, cmap__='rainbow',
                    markersize=1, sticks=None,
                    stickwidth=1, plot_point_order=False,
                    image=img,
                    win='projections_dense' )
    vis_utils.show_flow(viz, visdom_env_imgs, p,
        image=preds[imvar][idx_image].detach().cpu().numpy(),
        title='flow ' + title,
        linewidth=1,
        win='projections_flow',
    )

    if 'sph_sample_projs' in preds:
        # Projections of the sphere samples, optionally paired with the
        # ground-truth silhouette samples (tiled to the same point count).
        p = preds['sph_sample_projs'][idx_image].detach().cpu().view(2, -1)
        if 'sph_sample_gt' in preds:
            p_ = preds['sph_sample_gt'][idx_image].detach().cpu().view(2, -1)
            p_ = p_.repeat(1, math.ceil(p.shape[1]/p_.shape[1]))
            p = [p, p_[:, :p.shape[1]]]
        else:
            p = [p.view(2, -1)]
        # p = (torch.stack(p) + 1.) / 2.
        # Map [-1, 1] coordinates to pixel coordinates of the shown image.
        p = (torch.stack(p) + 1.) / 2.
        imsize = preds[imvar][idx_image].shape[1:]
        p[:, 0, :] *= imsize[1]
        p[:, 1, :] *= imsize[0]
        vis_utils.show_projections(viz, visdom_env_imgs,
            p, v=None,
            title=title + '_spl_sil',
            cmap__='rainbow',
            markersize=1, sticks=None,
            stickwidth=1, plot_point_order=False,
            image=preds[imvar][idx_image].detach().cpu().numpy(),
            win='projections_spl_sil'
        )

    # Fuse per-view embeddings into one cloud, then re-decode it with the
    # first view's global descriptor and camera; stash the result in preds.
    merged_embed = self._merge_masked_tensors(
        preds['embed_full'], preds['embed_masks']
    )[..., None]
    gl_desc_0 = {k: v[:1] for k, v in preds['phi'].items()}
    merged_with_pivot_phis = self._get_shapes_and_projections(
        merged_embed, None, gl_desc_0, preds['K'][:1]
    )
    preds['shape_canonical_same_alphas'] = merged_with_pivot_phis[
        'shape_canonical_dense'
    ][0 ,..., 0]

    # dense 3d
    pcl_show = {}
    vis_list = ['dense3d', 'mean_shape', 'embed_db', 'batch_fused', 'sph_embed']
    if self.loss_weights['loss_sph_sample_mask'] > 0:
        vis_list.append('sph_sample_3d')
    for vis in vis_list:
        # Rows 0:3 of `pcl` are xyz; rows 3:6 (when appended) act as RGB —
        # see the torch.cat((pcl, pcl_rgb)) below.
        if vis=='canonical':
            pcl = preds['shape_canonical_dense']
        elif vis=='dense3d':
            pcl = preds['shape_image_coord_cal_dense']
        elif vis=='batch_fused':
            pcl = preds['shape_canonical_same_alphas'].detach().cpu()
            # duplicate xyz as colour rows, then overwrite with a constant colour
            pcl = torch.cat((pcl, pcl), dim=0)
            pcl[3:5,:] = 0.0
            pcl[5,:] = 1.0
        elif vis=='mean_shape':
            pcl = preds['embed_mean']
        elif vis=='mean_c3dpo_shape':
            pcl = preds['nrsfm_mean_shape']
        elif vis=='shape_canonical':
            pcl = preds['shape_canonical_dense']
        elif vis == 'sph_embed':
            pcl = preds['embed'].detach().clone()
        elif vis == 'sph_sample_3d':
            pcl = preds['sph_sample_3d'][idx_image].detach().cpu().view(3, -1)
            pcl = torch.cat((pcl, pcl.clone()), dim=0)
            pcl[4:,:] = 0.0
            pcl[3,:] = 1.0
            # filtering outliers
            pcl[:3] -= pcl[:3].mean(dim=1, keepdim=True) # will be centered anyway
            std = pcl[:3].std(dim=1).max()
            pcl[:3] = pcl[:3].clamp(-2.5*std, 2.5*std)
        elif vis == 'embed_db':
            pcl = self.embed_db.get_db(uniform_sphere=False).cpu().detach().view(3, -1)
            pcl = torch.cat((pcl, pcl.clone()), dim=0)
            pcl[3:5,:] = 0.0
            pcl[4,:] = 1.0
        else:
            raise ValueError(vis)
        if vis not in ('mean_c3dpo_shape', 'batch_fused', 'sph_sample_3d', 'embed_db'):
            # Dense map case: colour points with the (resized) input image
            # and keep only points inside the (embed_masks) mask.
            pcl_rgb = preds[imvar].detach().cpu()
            #pcl = Fu.interpolate(pcl.detach().cpu(), pcl_rgb.shape[2:], mode='bilinear')
            pcl_rgb = Fu.interpolate(pcl_rgb, size=pcl.shape[2:], mode='bilinear')
            if (mvar in preds):
                masks = preds[mvar].detach().cpu()
                masks = Fu.interpolate(masks, \
                    size=pcl.shape[2:], mode='nearest')
            else:
                masks = None
            pcl = pcl.detach().cpu()[idx_image].view(3,-1)
            pcl_rgb = pcl_rgb[idx_image].view(3,-1)
            pcl = torch.cat((pcl, pcl_rgb), dim=0)
            if masks is not None:
                masks = masks[idx_image].view(-1)
                pcl = pcl[:,masks>0.]
        # if vis == 'sph_embed':
        # import pdb; pdb.set_trace()
        if pcl.numel()==0:
            continue
        pcl_show[vis] = pcl.numpy()
        # NOTE(review): pcl_show accumulates across iterations, so window
        # `vis` re-plots every cloud gathered so far — confirm intended.
        vis_utils.visdom_plotly_pointclouds(viz, pcl_show, visdom_env_imgs,
                    title=title+'_'+vis,
                    markersize=1,
                    sticks=None, win=vis,
                    height=700, width=700 ,
                    normalise=True,
                    )

    # Sparse 3D: nrsfm vs. dense-model keypoints (plus gt when available).
    var3d = {
        'orthographic': 'shape_image_coord',
        'perspective': 'shape_image_coord_cal',
    }[self.projection_type]
    sparse_pcl = {
        'nrsfm': preds['nrsfm'][var3d][idx_image].detach().cpu().numpy().copy(),
        'dense': preds['shape_image_coord_cal'][idx_image].detach().cpu().numpy().copy(),
    }
    if 'kp_loc_3d' in preds:
        sparse_pcl['gt'] = preds['kp_loc_3d'][idx_image].detach().cpu().numpy().copy()
    if 'class_mask' in preds:
        # zero-out keypoints outside the active class mask
        class_mask = preds['class_mask'][idx_image].detach().cpu().numpy()
        sparse_pcl = {k: v*class_mask[None] for k,v in sparse_pcl.items()}
    vis_utils.visdom_plotly_pointclouds(viz, sparse_pcl, visdom_env_imgs, \
                    title=title+'_sparse3d', \
                    markersize=5, \
                    sticks=None, win='nrsfm_3d',
                    height=500,
                    width=500 )

    if 'photo_out' in preds and preds['photo_out'] is not None:
        # show the source images and their renders
        ims_src = preds['photo_out']['images'].detach().cpu()
        ims_repro = preds['photo_out']['images_reproject'].detach().cpu()
        ims_reenact = preds['photo_out']['images_reenact'].detach().cpu()
        ims_gt = preds['photo_out']['images_gt'].detach().cpu()
        # cat all the images
        ims = torch.cat((ims_src,ims_reenact,ims_repro,ims_gt), dim=2)
        ims = torch.clamp(ims,0.,1.)
        viz.images(ims, env=visdom_env_imgs, opts={'title':title}, win='imrepro')
        # One extra window per render level: gt next to the rendered image.
        im_renders = preds['photo_out']['image_ref_render']
        for l in im_renders:
            im_gt = preds['photo_out']['images_gt'][0].detach().cpu()
            im_render = im_renders[l].detach().cpu()
            im = torch.cat((im_gt, im_render), dim=2)
            im = torch.clamp(im, 0., 1.)
            viz.image(im, env=visdom_env_imgs, \
                opts={'title':title+'_min_render_%s' % l}, win='imrender_%s' % l)

    if 'app_out' in preds and preds['app_out'] is not None:
        # show the source images and their predictions
        ims_src = preds['app_out']['images'].detach().cpu()
        ims_pred = preds['app_out']['images_pred_clamp'].detach().cpu()
        ims = torch.cat((ims_src,ims_pred), dim=2)
        viz.images(ims, env=visdom_env_imgs, opts={'title':title}, win='impred')