in models/gan_visualizer.py [0:0]
def visualizeNN(self,
                N,
                k,
                featureExtractor,
                imgTransform,
                nnSearch,
                names,
                pathDB):
    r"""
    Visualize the nearest neighbors of some random generations.

    Each generated image is followed by its ``k`` nearest neighbors from
    the database, so every row of the published grid holds ``k + 1``
    images.

    Args:
        N (int): number of generations to make. Nothing is published if
            N <= 0.
        k (int): number of neighbors to fetch for each generation
        featureExtractor (nn.Module): feature extractor applied to the
            transformed generations
        imgTransform (nn.Module): image transform module applied to the
            generations before feature extraction
        nnSearch: search tree for the features; must expose
            ``query(features, k)`` returning ``(distances, indexes)``
            (e.g. a scipy / sklearn KD-tree)
        names (list): a match between an image index and its name
        pathDB (string): path to the database. A path ending in ``.h5`` is
            opened as an H5Dataset; otherwise it is treated as a directory
            in which ``names[i]`` is the relative path of image ``i``.
    """
    if N <= 0:
        return

    batchSize = 16
    nImages = 0
    vectorOut = []

    size = self.model.getSize()[0]

    # Used for images loaded from a directory: resize to the generator's
    # output size, then map pixel values with Normalize(0.5, 0.5).
    transform = Transforms.Compose([NumpyResize((size, size)),
                                    NumpyToTensor(),
                                    Transforms.Normalize((0.5, 0.5, 0.5),
                                                         (0.5, 0.5, 0.5))])

    dataset = None
    if os.path.splitext(pathDB)[1] == ".h5":
        dataset = H5Dataset(pathDB,
                            transform=Transforms.Compose(
                                [NumpyToTensor(),
                                 Transforms.Normalize((0.5, 0.5, 0.5),
                                                      (0.5, 0.5, 0.5))]))

    def loadNeighbor(i):
        # Load database image i as a (1, C, size, size) tensor.
        if dataset is None:
            path = os.path.join(pathDB, names[i])
            return transform(pil_loader(path)).unsqueeze(0)
        imgSource, _ = dataset[names[i]]
        # H5 images are not resized by their transform, so rescale here
        # to match the generations in the grid.
        return F.interpolate(imgSource.unsqueeze(0), size=(size, size),
                             mode='bilinear', align_corners=False)

    while nImages < N:
        noiseData, _ = self.model.buildNoiseData(batchSize)

        # Results are detached anyway: skip building autograd graphs.
        with torch.no_grad():
            imgOut = self.model.test(
                noiseData, getAvG=True, toCPU=False).detach()
            features = featureExtractor(imgTransform(imgOut)).detach()
            features = features.reshape(imgOut.size(0), -1).cpu().numpy()

        _, indexes = nnSearch.query(features, k)
        # scipy's KD-tree squeezes the neighbor axis when k == 1; restore
        # a (batch, k) layout so indexes[p][ki] is always valid.
        indexes = indexes.reshape(imgOut.size(0), -1)

        # Consume this batch now (previous code only kept the last batch
        # and indexed it with range(N), overflowing when N > batchSize).
        nKeep = min(imgOut.size(0), N - nImages)
        for p in range(nKeep):
            vectorOut.append(imgOut[p:p + 1].cpu())
            vectorOut.extend(loadNeighbor(indexes[p][ki])
                             for ki in range(k))
        nImages += nKeep

    vectorOut = torch.cat(vectorOut, dim=0)
    self.visualizer.publishTensors(vectorOut, (224, 224), nrow=k + 1)