in criterions/avid_cma.py [0:0]
def find_correspondences(self):
    """Compute CMA positive correspondences and store them in the ``positive_set`` buffer.

    Returns immediately (doing nothing) when ``self.sampling_args['pos_k']`` is
    not positive.

    Sampling via ``CMASampler(...).sample()`` runs on a single process only: the
    sole process when ``self.distributed`` is false, or rank 0 otherwise. Every
    other rank keeps an all-zeros integer array of shape
    ``(self.view1_mem.shape[0], pos_k)``. That array only serves as the receive
    buffer for ``dist.broadcast``, which overwrites it with rank 0's result.

    If the positive set is not None, it is registered as the ``positive_set``
    buffer (converted with ``.int()``), moved to ``self.device``, and, in
    distributed mode, broadcast from rank 0. In distributed mode every rank then
    waits on ``dist.barrier()`` before returning.
    """
    pos_k = self.sampling_args['pos_k']
    if pos_k <= 0:
        return

    # Receive buffer for non-zero ranks; allocated directly as an integer array
    # instead of building a float array and casting it.
    # NOTE(review): dist.broadcast requires identical shapes across ranks, so this
    # assumes the sampler returns an array of shape (view1_mem.shape[0], pos_k).
    positive_set = np.zeros((self.view1_mem.shape[0], pos_k), dtype=int)

    # Only one process performs the CMA sampling.
    # (`not d or d and r == 0` is equivalent to `not d or r == 0`.)
    if not self.distributed or self.rank == 0:
        torch.cuda.empty_cache()
        positive_set = CMASampler(self.view1_mem, self.view2_mem, self.sampling_args).sample()

    # Register the correspondences as a buffer on this process's device. In
    # distributed mode, overwrite every rank's copy with rank 0's result.
    # NOTE(review): if sample() could return None on rank 0 in a distributed run,
    # rank 0 would skip dist.broadcast while the other ranks (holding the zeros
    # buffer) would enter it, giving a collective mismatch. Confirm this cannot happen.
    if positive_set is not None:
        self.register_buffer('positive_set', torch.from_numpy(positive_set).int())
        self.positive_set = self.positive_set.cuda(self.device)
        if self.distributed:
            dist.broadcast(self.positive_set, 0)

    # Keep all ranks in lockstep before anyone uses the new positive set.
    if self.distributed:
        dist.barrier()