def _batch_shuffle_ddp()

in models/base_ssl3d_model.py [0:0]


    def _batch_shuffle_ddp(self, x, vox=False, idx_shuffle=None):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        if vox:
            batch_size = []
            for bidx in x:
                batch_size.append(len(bidx))
            all_size = concat_all_gather(torch.tensor(batch_size).cuda())
            max_size = torch.max(all_size)

            ### Change the new size here
            newx = []
            for bidx in range(len(x)):
                newx.append(torch.ones((max_size, x[bidx].shape[1])).cuda())
                newx[bidx][:len(x[bidx]),:] = x[bidx]
            newx = torch.stack(newx)
            batch_size_this = newx.shape[0]
        else:
            batch_size_this = x.shape[0]

        if vox:
            x_gather = concat_all_gather(newx)
        else:
            x_gather = concat_all_gather(x)

        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        if idx_shuffle == None:
            # random shuffle index
            idx_shuffle = torch.randperm(batch_size_all).cuda()
        
            # broadcast to all gpus
            torch.distributed.broadcast(idx_shuffle, src=0)
        
        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)
        
        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        if vox:
            ret_x = []
            batch_idx = []
            for idx in range(len(idx_this)):
                if x_gather.shape[-1] == 4:
                    ### Change the batch index here
                    tempdata = x_gather[idx_this[idx]][:all_size[idx_this[idx]],:]
                    tempdata[:,0] = idx
                    ret_x.append(tempdata)
                else:
                    ret_x.append(x_gather[idx_this[idx]][:all_size[idx_this[idx]],:])
            ret_x = torch.cat(ret_x)
            return ret_x, idx_unshuffle, idx_shuffle
        else:
            return x_gather[idx_this], idx_unshuffle