in models/base_ssl3d_model.py [0:0]
def _single_input_forward_MOCO(self, batch, feature_names, input_key, target):
    """Run the momentum (key) trunk on one input and return its features.

    The method optionally wraps raw voxel data in a ``SparseTensor``,
    calls ``self._momentum_update_key(target)``, shuffles non-voxel
    batches through ``self._batch_shuffle_ddp`` when ``torch.distributed``
    is initialised, copies the batch to the GPU, runs
    ``self.trunk[target]`` and finally undoes the shuffle.

    Args:
        batch: Input for this modality.
            * non-voxel keys: must be a ``torch.Tensor``;
            * voxel keys without ``"Lidar"`` in ``self.config``: indexed
              as ``batch[0]`` (coordinates) and ``batch[1]`` (features);
            * voxel keys with ``"Lidar"`` in ``self.config``: iterated
              like a mapping, each value is copied to the GPU in place.
        feature_names: Forwarded unchanged to the trunk.
        input_key: Name of the input; voxel handling is selected by the
            substring ``"vox"``.
        target: Selects the trunk in ``self.trunk`` and is passed to
            ``self._momentum_update_key``.

    Returns:
        The trunk output. When ``torch.distributed`` is initialised and
        the input is not a voxel input, a one-element list holding the
        un-shuffled first trunk output.

    Note:
        For raw voxel input the coordinate tensor ``batch[0]`` is shifted
        in place, so the caller's tensor is modified.
    """
    is_vox = "vox" in input_key
    is_lidar = "Lidar" in self.config
    raw_vox = is_vox and not is_lidar

    if not is_vox:
        assert isinstance(batch, torch.Tensor), (
            "non-voxel input must be a torch.Tensor, got %s"
            % type(batch).__name__
        )

    if raw_vox:
        points_coords = batch[0]
        points_feats = batch[1]

        # Invariant to even and odd coords: one random offset in [0, 100)
        # per column is added to every column except column 0 (the
        # disabled shuffle code treated column 0 as the batch index).
        # This is an in-place update of the caller's tensor.
        points_coords[:, 1:] += (torch.rand(3) * 100).type_as(points_coords)
        points_feats = points_feats / 255.0 - 0.5

        batch = SparseTensor(points_feats, points_coords.float())

    # NOTE(review): the source indentation was lost; the whole key branch
    # (including the trunk forward) is placed under no_grad -- confirm.
    with torch.no_grad():
        self._momentum_update_key(target)  # update the key encoder

        distributed = torch.distributed.is_initialized()
        # Shuffle for making use of BN. Batch shuffling for voxel input is
        # intentionally skipped: the original code kept it behind a
        # hard-coded ``if False`` with the remark that it gave no
        # performance gain.
        shuffled = distributed and not is_vox
        if shuffled:
            batch, idx_unshuffle = self._batch_shuffle_ddp(batch, vox=False)
        elif not distributed and raw_vox:
            # NOTE(review): rebuilds the SparseTensor from the same,
            # unchanged tensors used above; kept so the sequence of
            # constructor calls matches the original.
            batch = SparseTensor(points_feats, points_coords.float())

        # Copy to GPU
        if is_lidar and is_vox:
            for key in batch:
                batch[key] = main_utils.recursive_copy_to_gpu(
                    batch[key], non_blocking=True
                )
        else:
            batch = main_utils.recursive_copy_to_gpu(batch, non_blocking=True)

        feats = self.trunk[target](batch, feature_names)
        if shuffled:
            # Only the first trunk output is restored to the original order.
            feats = [self._batch_unshuffle_ddp(feats[0], idx_unshuffle)]
        return feats