in models/base_ssl3d_model.py [0:0]
def _batch_shuffle_ddp(self, x, vox=False, idx_shuffle=None):
    """
    Batch shuffle, for making use of BatchNorm.
    *** Only support DistributedDataParallel (DDP) model. ***

    The local batch is gathered from every process, a permutation of the
    global batch is applied, and the slice of the permuted batch that
    belongs to the calling rank is returned.

    Args:
        x: When ``vox`` is False, a tensor whose first dimension is the
            local batch. When ``vox`` is True, a sequence of 2-D tensors
            (one per sample) that may have different numbers of rows.
        vox: Select the variable-length (padded) code path.
        idx_shuffle: Optional precomputed permutation of the global batch.
            When None, a fresh permutation is drawn and broadcast from
            rank 0.

    Returns:
        ``(x_shuffled, idx_unshuffle)`` when ``vox`` is False, and
        ``(x_shuffled, idx_unshuffle, idx_shuffle)`` when ``vox`` is True.
        In the ``vox`` case the samples are stripped of their padding and
        concatenated along the first dimension.
    """
    if vox:
        # Samples have different row counts, so every rank needs the sizes
        # of all samples to pad before gathering and to trim afterwards.
        local_sizes = [len(sample) for sample in x]
        all_size = concat_all_gather(torch.tensor(local_sizes).cuda())
        max_size = torch.max(all_size)

        # Pad each sample with ones up to the global maximum row count so
        # the samples can be stacked into one rectangular tensor.
        padded = []
        for sample in x:
            buf = torch.ones((max_size, sample.shape[1])).cuda()
            buf[:len(sample), :] = sample
            padded.append(buf)
        local_batch = torch.stack(padded)
    else:
        local_batch = x

    batch_size_this = local_batch.shape[0]

    # gather from all gpus
    x_gather = concat_all_gather(local_batch)
    batch_size_all = x_gather.shape[0]
    num_gpus = batch_size_all // batch_size_this

    if idx_shuffle is None:
        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()
        # broadcast to all gpus
        # NOTE(review): a caller-supplied idx_shuffle is not re-broadcast;
        # it is assumed to already be identical on every rank -- confirm.
        torch.distributed.broadcast(idx_shuffle, src=0)

    # index for restoring
    idx_unshuffle = torch.argsort(idx_shuffle)

    # shuffled index for this gpu
    gpu_idx = torch.distributed.get_rank()
    idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

    if not vox:
        return x_gather[idx_this], idx_unshuffle

    # Four-column samples get column 0 rewritten to the sample's position
    # in the shuffled local batch (the original marks this as the batch
    # index).
    rewrite_batch_column = x_gather.shape[-1] == 4
    ret_x = []
    for local_pos, src in enumerate(idx_this):
        # Drop the padding rows that were added before the gather.
        sample = x_gather[src][:all_size[src], :]
        if rewrite_batch_column:
            # In-place write on a view of x_gather, which is a buffer
            # created by the gather above rather than the caller's input.
            sample[:, 0] = local_pos
        ret_x.append(sample)
    ret_x = torch.cat(ret_x)
    return ret_x, idx_unshuffle, idx_shuffle