def find_correspondences()

in criterions/avid_cma.py [0:0]


    def find_correspondences(self):
        """Find CMA correspondences and store them in the ``positive_set`` buffer.

        Does nothing when ``self.sampling_args['pos_k'] <= 0``.

        When ``self.distributed`` is set, the sampling runs only on rank 0. The
        result is registered as an int32 buffer named ``positive_set``, moved to
        ``self.device`` and broadcast from rank 0 to the other ranks. All ranks
        meet at a barrier before returning.

        Non-zero ranks receive into a zero-filled buffer of shape
        ``(self.view1_mem.shape[0], pos_k)``. Rank 0's sampler output is
        therefore assumed to have that same shape; NOTE(review): confirm this
        against ``CMASampler.sample``.

        If the sampler returns ``None``, no buffer is registered on any rank.
        """
        pos_k = self.sampling_args['pos_k']
        if pos_k <= 0:
            return

        # Find CMA correspondences. Only do this on one process if running in
        # distributed mode and sync at the end.
        is_sampling_rank = not self.distributed or self.rank == 0
        if is_sampling_rank:
            # Presumably frees cached GPU memory before the sampler runs.
            torch.cuda.empty_cache()
            positive_set = CMASampler(self.view1_mem, self.view2_mem, self.sampling_args).sample()
        else:
            # Receive buffer for the broadcast below. After .int() its dtype
            # matches rank 0's buffer.
            positive_set = np.zeros((self.view1_mem.shape[0], pos_k)).astype(int)

        if self.distributed:
            # All ranks must agree on whether the broadcast below happens.
            # Otherwise a None result on rank 0 would leave the other ranks
            # blocked inside dist.broadcast.
            has_result = torch.tensor([int(positive_set is not None)], dtype=torch.int32).cuda(self.device)
            dist.broadcast(has_result, 0)
            if not has_result.item():
                positive_set = None

        # Register the result as a module buffer, move it to this process's
        # device and share rank 0's values with every other rank.
        if positive_set is not None:
            self.register_buffer('positive_set', torch.from_numpy(positive_set).int())
            self.positive_set = self.positive_set.cuda(self.device)
            if self.distributed:
                dist.broadcast(self.positive_set, 0)

        if self.distributed:
            dist.barrier()