def visualizeNN()

in models/gan_visualizer.py [0:0]


    def visualizeNN(self,
                    N,
                    k,
                    featureExtractor,
                    imgTransform,
                    nnSearch,
                    names,
                    pathDB):
        r"""
        Visualize the nearest neighbors of some random generations

        Args:

            N (int): number of generations to make
            k (int): number of nearest neighbors to fetch for each generation
            featureExtractor (nn.Module): feature extractor
            imgTransform (nn.Module): image transform module
            nnSearch (KDTree): search tree built on the dataset features
            names (list): mapping between a feature index and the
                          corresponding image name (or dataset index when
                          pathDB points to an HDF5 file)
            pathDB (string): path to the dataset (image folder or .h5 file)
        """

        batchSize = 16
        nImages = 0

        vectorOut = []

        size = self.model.getSize()[0]

        transform = Transforms.Compose([NumpyResize((size, size)),
                                        NumpyToTensor(),
                                        Transforms.Normalize((0.5, 0.5, 0.5),
                                                             (0.5, 0.5, 0.5))])

        dataset = None

        # When the dataset is stored as an HDF5 file, the neighbors are read
        # directly from it instead of being loaded image by image from disk
        if os.path.splitext(pathDB)[1] == ".h5":
            dataset = H5Dataset(pathDB,
                                transform=Transforms.Compose(
                                    [NumpyToTensor(),
                                     Transforms.Normalize((0.5, 0.5, 0.5),
                                                          (0.5, 0.5, 0.5))]))

        while nImages < N:

            # Generate a batch of images with the averaged generator
            noiseData, _ = self.model.buildNoiseData(batchSize)
            imgOut = self.model.test(
                noiseData, getAvG=True, toCPU=False).detach()

            # Extract the features of the generated images and query their
            # k nearest neighbors in the dataset
            features = featureExtractor(imgTransform(imgOut)).detach().view(
                imgOut.size(0), -1).cpu().numpy()
            distances, indexes = nnSearch.query(features, k)

            # Only keep as many generations as are still needed to reach N
            # (iterating over N directly would overflow the batch when N >
            # batchSize)
            nToKeep = min(batchSize, N - nImages)
            nImages += nToKeep

            for p in range(nToKeep):

                vectorOut.append(imgOut[p].cpu().view(
                    1, imgOut.size(1), imgOut.size(2), imgOut.size(3)))
                for ki in range(k):

                    i = indexes[p][ki]
                    if dataset is None:
                        # Folder dataset: load the neighbor image from disk
                        path = os.path.join(pathDB, names[i])
                        imgSource = transform(pil_loader(path))
                        imgSource = imgSource.view(1, imgSource.size(0),
                                                   imgSource.size(1),
                                                   imgSource.size(2))

                    else:
                        # HDF5 dataset: fetch the neighbor and resize it to
                        # the generator's output resolution
                        imgSource, _ = dataset[names[i]]
                        imgSource = imgSource.view(1, imgSource.size(0),
                                                   imgSource.size(1),
                                                   imgSource.size(2))
                        imgSource = F.interpolate(
                            imgSource, size=(size, size), mode='bilinear',
                            align_corners=False)

                    vectorOut.append(imgSource)

        # Each row of the published grid shows one generation followed by
        # its k nearest neighbors from the dataset
        vectorOut = torch.cat(vectorOut, dim=0)
        self.visualizer.publishTensors(vectorOut, (224, 224), nrow=k + 1)
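
For context, a minimal calling sketch is given below. It assumes ganVisualizer is an instance of the class exposing this method, that datasetFeatures were precomputed by running the same feature extractor over the real dataset, and that the search tree is a scipy.spatial.cKDTree, whose query() signature matches the call made above. The feature-extractor choice and all variable names outside this file are illustrative assumptions, not part of the repository.

    # Illustrative sketch only: "ganVisualizer", "datasetFeatures" and
    # "names" are assumptions; in practice the features come from the
    # real dataset, not from random numbers.
    import numpy as np
    import torch
    import torchvision.models as tvmodels
    from scipy.spatial import cKDTree

    # A headless ResNet as feature extractor, plus a resize module so the
    # generated images match the extractor's expected input resolution
    resnet = tvmodels.resnet18(pretrained=True).eval()
    featureExtractor = torch.nn.Sequential(*list(resnet.children())[:-1])
    imgTransform = torch.nn.Upsample(size=(224, 224), mode='bilinear',
                                     align_corners=False)

    # Precomputed dataset features (one row per image) and matching names
    datasetFeatures = np.random.rand(1000, 512).astype(np.float32)  # dummy
    names = ["img_%05d.jpg" % i for i in range(1000)]
    nnSearch = cKDTree(datasetFeatures)

    # Show 8 generations, each next to its 5 nearest neighbors
    ganVisualizer.visualizeNN(8, 5, featureExtractor, imgTransform,
                              nnSearch, names, pathDB="path/to/dataset")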