in models/metrics/nn_score.py [0:0]
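# Standard-library and third-party imports this excerpt relies on; the
# project-local helpers (H5Dataset, AttribDataset, printProgressBar) are
# assumed to be imported from elsewhere in the repository, and their exact
# module paths are not shown here.
import os
import pickle

import scipy.spatial
import torch
import torch.nn as nn
import torchvision.transforms as Transforms
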
def saveFeatures(model,
                 imgTransform,
                 pathDB,
                 pathMask,
                 pathAttrib,
                 outputFile,
                 pathPartition=None,
                 partitionValue=None):
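    """
    Run `model` on every image of the dataset at `pathDB`, flatten each
    output into a single feature row, build a KD-tree over those rows and
    pickle the tree together with the image names to `outputFile`.
    """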
    batchSize = 16

    transformList = [Transforms.ToTensor(),
                     Transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    transform = Transforms.Compose(transformList)

    device = torch.device("cuda:0")
    n_devices = torch.cuda.device_count()

    # Wrap both the feature extractor and the image transform in DataParallel
    # so each batch is split across all available GPUs.
    parallelModel = nn.DataParallel(model).to(device).eval()
    parallelTransform = nn.DataParallel(imgTransform).to(device)
    # Pick the dataset class from the file extension: HDF5 databases go
    # through H5Dataset, everything else through AttribDataset.
    if os.path.splitext(pathDB)[1] == ".h5":
        dataset = H5Dataset(pathDB,
                            transform=transform,
                            pathDBMask=pathMask,
                            partition_path=pathPartition,
                            partition_value=partitionValue)
    else:
        dataset = AttribDataset(pathDB, transform=transform,
                                attribDictPath=pathAttrib,
                                specificAttrib=None,
                                mimicImageFolder=False,
                                pathMask=pathMask)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batchSize,
                                         shuffle=False,
                                         num_workers=n_devices)

    outFeatures = []
    nImg = 0
    totImg = len(dataset)
    for item in loader:
        if len(item) == 3:
            data, label, mask = item
        else:
            data, label = item
            mask = None

        printProgressBar(nImg, totImg)
        # One flattened feature row per image in the batch.
        features = parallelModel(parallelTransform(
            data)).detach().view(data.size(0), -1).cpu()
        outFeatures.append(features)
        # Count the images actually seen, so the progress bar stays accurate
        # on a final, smaller batch.
        nImg += data.size(0)

    printProgressBar(totImg, totImg)
    # Pickling a deep scipy KDTree recurses through its nodes and can hit
    # Python's default recursion limit, hence the raised ceiling here.
    import sys
    sys.setrecursionlimit(10000)

    outFeatures = torch.cat(outFeatures, dim=0).numpy()
    tree = scipy.spatial.KDTree(outFeatures, leafsize=10)
    names = [dataset.getName(x) for x in range(totImg)]

    with open(outputFile, 'wb') as file:
        pickle.dump([tree, names], file)
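

# A minimal, hypothetical usage sketch (not part of the original file): it
# shows one way the pickle written by saveFeatures could be read back and
# queried. The function name, its arguments and `probeFeatures` are
# illustrative assumptions, not part of this repository's API.
def queryNearestNeighbors(featureFile, probeFeatures, k=5):
    # Load the KD-tree and the image names saved by saveFeatures.
    with open(featureFile, 'rb') as file:
        tree, names = pickle.load(file)

    # k nearest dataset images to the probe feature vector, by euclidean
    # distance in feature space. `probeFeatures` must be a 1-D array with the
    # same dimensionality as the rows indexed by the tree (k > 1, so scipy
    # returns length-k arrays).
    distances, indices = tree.query(probeFeatures, k=k)
    return [(names[i], distances[j]) for j, i in enumerate(indices)]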