def saveFeatures()

in models/metrics/nn_score.py [0:0]


def saveFeatures(model,
                 imgTransform,
                 pathDB,
                 pathMask,
                 pathAttrib,
                 outputFile,
                 pathPartition=None,
                 partitionValue=None):
    r"""
    Embed every image of a dataset with `model` and pickle a KD-tree built
    over the resulting feature vectors, together with the image names.

    Args:
        model (nn.Module): feature extractor; its output for each image is
            flattened into a single feature vector.
        imgTransform (nn.Module): module applied to each normalized batch
            before it is fed to `model`.
        pathDB (str): path to the image database. A ".h5" file is loaded with
            H5Dataset, anything else with AttribDataset.
        pathMask (str or None): mask database path, forwarded to the dataset.
        pathAttrib (str or None): attribute dictionary path (only used for
            non-".h5" databases).
        outputFile (str): destination file. It receives the pickled list
            ``[tree, names]`` where ``tree`` is a ``scipy.spatial.KDTree`` over
            the features (row i = i-th dataset item) and ``names[i]`` is
            ``dataset.getName(i)``.
        pathPartition (str or None): partition file (".h5" databases only).
        partitionValue: partition value to select (".h5" databases only).

    Note:
        `model` and `imgTransform` are moved to the compute device, and
        `model` is switched to eval mode, in place.
    """
    import sys

    batchSize = 16

    transform = Transforms.Compose([
        Transforms.ToTensor(),
        Transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Fall back to the CPU when no GPU is present: without CUDA devices,
    # nn.DataParallel simply forwards to the wrapped module.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    n_devices = torch.cuda.device_count()

    parallelModel = nn.DataParallel(model).to(device).eval()
    parallelTransform = nn.DataParallel(imgTransform).to(device)

    if os.path.splitext(pathDB)[1] == ".h5":
        dataset = H5Dataset(pathDB,
                            transform=transform,
                            pathDBMask=pathMask,
                            partition_path=pathPartition,
                            partition_value=partitionValue)
    else:
        dataset = AttribDataset(pathDB, transform=transform,
                                attribDictPath=pathAttrib,
                                specificAttrib=None,
                                mimicImageFolder=False,
                                pathMask=pathMask)

    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batchSize,
                                         shuffle=False,
                                         num_workers=n_devices)

    outFeatures = []
    nImg = 0
    totImg = len(dataset)

    # Pure inference: no_grad avoids building (and holding) the autograd
    # graph for every batch.
    with torch.no_grad():
        for item in loader:
            # Items are (data, label) or (data, label, mask); only the images
            # are needed to compute features.
            data = item[0]

            printProgressBar(nImg, totImg)
            features = parallelModel(parallelTransform(data))
            outFeatures.append(features.reshape(data.size(0), -1).cpu())

            # The last batch may hold fewer than batchSize images.
            nImg += data.size(0)

    printProgressBar(totImg, totImg)

    outFeatures = torch.cat(outFeatures, dim=0).numpy()
    names = [dataset.getName(x) for x in range(totImg)]

    # Building / pickling the KD-tree may recurse deeply (presumably the
    # reason the limit is raised). Only raise it, and restore the caller's
    # limit afterwards instead of leaking global interpreter state.
    oldLimit = sys.getrecursionlimit()
    sys.setrecursionlimit(max(oldLimit, 10000))
    try:
        tree = scipy.spatial.KDTree(outFeatures, leafsize=10)
        with open(outputFile, 'wb') as file:
            pickle.dump([tree, names], file)
    finally:
        sys.setrecursionlimit(oldLimit)