in lib/model/HGPIFuNetwNML.py [0:0]
def query(self, points, calibs, transforms=None, labels=None, update_pred=True, update_phi=True):
    '''
    Project 3D points into image space and evaluate the implicit field there.
    filter() must have been called beforehand so image features are cached in
    self.im_feat_list. The final prediction is stored to self.preds (not returned).
    args:
        points: [B, 3, N] 3d points in world space
        calibs: [B, 3, 4] calibration matrices for each image
        transforms: [B, 2, 3] image space coordinate transforms
        labels: [B, C, N] ground truth labels (for supervision only)
        update_pred: if True, store per-stack predictions on self
        update_phi: if True, store the last MLP intermediate feature on self
    '''
    projected = self.projection(points, calibs, transforms)
    uv = projected[:, :2, :]

    # Points outside the [-1, 1]^3 bounding box are forced to the "outside" value (0).
    inside = (projected >= -1) & (projected <= 1)
    mask = inside[:, 0, :] & inside[:, 1, :] & inside[:, 2, :]
    mask = mask[:, None, :].detach().float()

    if labels is not None:
        # Zero out supervision for out-of-bounds points as well.
        self.labels = mask * labels

    spatial_feat = self.spatial_enc(projected, calibs=calibs)

    preds_per_stack = []
    latent = None
    for im_feat in self.im_feat_list:
        # Fuse sampled image features with the spatial encoding along channels.
        fused = torch.cat([self.index(im_feat, uv), spatial_feat], 1)
        pred, latent = self.mlp(fused)
        preds_per_stack.append(mask * pred)

    if update_phi:
        self.phi = latent

    if update_pred:
        self.intermediate_preds_list = preds_per_stack
        self.preds = preds_per_stack[-1]