in lib/model/HGPIFuMRNet.py [0:0]
def query(self, points, calib_local, calib_global=None, transforms=None, labels=None):
    '''
    Given 3D points, project them into each local view with the provided camera
    matrices, predict with self.mlp, and store the results on self.
    filter needs to be called beforehand (self.im_feat_list and the state of
    self.netG are read here).

    args:
        points: [B1, B2, 3, N] 3d points in world space
            ([B1, 3, N] when calib_global is None, i.e. single-view mode)
        calib_local: [B1, B2, 4, 4] calibration matrices for each local image
            ([B1, 4, 4] in single-view mode)
        calib_global: [B1, 4, 4] calibration matrices for the global image;
            if None, calib_local is reused as the global calibration and a
            single view (B2 = 1) is assumed
        transforms: [B1, 2, 3] image space coordinate transforms
        labels: [B1, B2, C, N] ground truth labels (for supervision only);
            in single-view mode [B1, C, N] is also accepted
    return:
        None. Results are stored as attributes:
        self.preds: final prediction of every view, concatenated along dim 0
        self.preds_interm: per-stage predictions, concatenated along dim 1
        self.preds_low: netG intermediate predictions, concatenated along dim 1
        self.w, self.gamma, self.labels: only set when labels is given
    '''
    if calib_global is not None:
        B = calib_local.size(1)
    else:
        # Single-view mode: add a view axis of size 1 so that the loop below
        # can index every per-view tensor with [:, i].
        B = 1
        points = points[:, None]
        calib_global = calib_local
        calib_local = calib_local[:, None]
        # Labels need the same view axis; without it labels[:, 0] would pick
        # channel 0 and broadcast against in_bb into the wrong shape.
        # Labels already given as [B1, 1, C, N] are left untouched.
        if labels is not None and labels.dim() == 3:
            labels = labels[:, None]

    ws = []
    preds = []
    preds_interm = []
    preds_low = []
    gammas = []
    newlabels = []
    for i in range(B):
        xyz = self.projection(points[:, i], calib_local[:, i], transforms)
        xy = xyz[:, :2, :]

        # If the point is outside the bounding box, its prediction is masked
        # out. Only the first two projected coordinates are tested.
        in_bb = (xyz >= -1) & (xyz <= 1)
        in_bb = in_bb[:, 0, :] & in_bb[:, 1, :]
        in_bb = in_bb[:, None, :].detach().float()

        # Query netG with the global calibration; its phi features and
        # intermediate predictions are consumed below.
        self.netG.query(points=points[:, i], calibs=calib_global)
        preds_low.append(torch.stack(self.netG.intermediate_preds_list, 0))

        if labels is not None:
            newlabels.append(in_bb * labels[:, i])
            with torch.no_grad():
                # Per-sample number of in-box points.
                # NOTE(review): if a sample has no point inside the box this
                # is 0 and w / gamma become inf / nan.
                n_in = in_bb.view(in_bb.size(0), -1).sum(1)
                ws.append(in_bb.size(2) / n_in)
                gammas.append(1 - newlabels[-1].view(newlabels[-1].size(0), -1).sum(1) / n_in)

        z_feat = self.netG.phi
        if not self.opt.train_full_pifu:
            # Stop gradients from flowing back into netG through z_feat.
            z_feat = z_feat.detach()

        intermediate_preds_list = []
        for im_feat in self.im_feat_list:
            # The leading dim of im_feat is unflattened to (-1, B) and the
            # slice belonging to view i is selected.
            view_feat = im_feat.view(-1, B, *im_feat.size()[1:])[:, i]
            point_local_feat = torch.cat([self.index(view_feat, xy), z_feat], 1)
            pred = self.mlp(point_local_feat)[0]
            # Zero out predictions for points outside the bounding box.
            pred = in_bb * pred
            intermediate_preds_list.append(pred)

        preds_interm.append(torch.stack(intermediate_preds_list, 0))
        preds.append(intermediate_preds_list[-1])

    self.preds = torch.cat(preds, 0)
    self.preds_interm = torch.cat(preds_interm, 1)  # first dim is for intermediate predictions
    self.preds_low = torch.cat(preds_low, 1)  # first dim is for intermediate predictions

    if labels is not None:
        self.w = torch.cat(ws, 0)
        self.gamma = torch.cat(gammas, 0)
        self.labels = torch.cat(newlabels, 0)