def buildFeatureMaker()

in models/metrics/nn_score.py [0:0]


def buildFeatureMaker(pathDB,
                      pathTrainAttrib,
                      pathValAttrib,
                      specificAttrib=None,
                      visualisation=None):
    """Fine-tune an ImageNet-pretrained ResNet-18 to predict image attributes.

    The final fully-connected layer of the ResNet-18 is replaced so that its
    output size equals ``ACGANCriterion(dataset.getKeyOrders()).getInputDim()``.
    The whole network is then trained with Adam for a fixed number of epochs.

    Args:
        pathDB: root of the image database, forwarded to AttribDataset.
        pathTrainAttrib: attribute-dictionary path for the training images.
        pathValAttrib: attribute-dictionary path for the validation images.
        specificAttrib: optional attribute selection forwarded to
            AttribDataset (None keeps AttribDataset's default).
        visualisation: optional object exposing ``publishLinePlot``. When it
            is given, the training and validation losses are plotted and
            printed every 100 steps. When it is None, no validation is run.

    Returns:
        The trained ResNet-18, unwrapped from nn.DataParallel. It remains on
        the device used for training.
    """
    # Parameters
    batchSize = 16
    nEpochs = 3
    learningRate = 1e-4
    beta1 = 0.9
    beta2 = 0.99
    logEvery = 100       # steps between two train/validation reports
    maxValBatches = 100  # cap on validation batches evaluated per report
    # Fall back to the CPU instead of crashing when no GPU is available.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    n_devices = torch.cuda.device_count()

    # Model
    resnet18 = models.resnet18(pretrained=True)
    resnet18.train()

    # Dataset
    size = 224
    normalize = Transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))
    trainTransform = Transforms.Compose([Transforms.Resize((size, size)),
                                         Transforms.RandomHorizontalFlip(),
                                         Transforms.ToTensor(),
                                         normalize])
    # No random augmentation on validation data: it would only add noise
    # to the reported validation loss.
    valTransform = Transforms.Compose([Transforms.Resize((size, size)),
                                       Transforms.ToTensor(),
                                       normalize])

    dataset = AttribDataset(pathDB, transform=trainTransform,
                            attribDictPath=pathTrainAttrib,
                            specificAttrib=specificAttrib,
                            mimicImageFolder=False)

    validationDataset = AttribDataset(pathDB, transform=valTransform,
                                      attribDictPath=pathValAttrib,
                                      specificAttrib=specificAttrib,
                                      mimicImageFolder=False)

    print("%d training images detected, %d validation images detected"
          % (len(dataset), len(validationDataset)))

    lossMode = ACGANCriterion(dataset.getKeyOrders())

    # Replace the ImageNet classification head with one sized for our labels.
    num_ftrs = resnet18.fc.in_features
    resnet18.fc = nn.Linear(num_ftrs, lossMode.getInputDim())
    resnet18 = nn.DataParallel(resnet18).to(device)

    # Optimization. The optimizer MUST be built after the head is replaced
    # and the model is moved. Built earlier, it would hold the parameters of
    # the discarded fc layer, and the new head would never be updated.
    optimizer = torch.optim.Adam(resnet18.parameters(),
                                 betas=(beta1, beta2),
                                 lr=learningRate)

    # Loaders are built once. A shuffling DataLoader draws a new random
    # order every time it is iterated, so epochs are still reshuffled.
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batchSize,
                                         shuffle=True,
                                         num_workers=n_devices)
    validationLoader = torch.utils.data.DataLoader(validationDataset,
                                                   batch_size=batchSize,
                                                   shuffle=True,
                                                   num_workers=n_devices)

    # Visualization data
    lossTrain = []
    lossVal = []
    iterList = []
    tokenTrain = None
    tokenVal = None
    step = 0
    tmpLoss = 0

    for epoch in range(nEpochs):

        for batchIdx, (inputs_real, labels) in enumerate(loader):

            optimizer.zero_grad()

            inputs_real = inputs_real.to(device)
            labels = labels.to(device)

            predictedLabels = resnet18(inputs_real)
            loss = lossMode.getLoss(predictedLabels, labels)
            tmpLoss += loss.item()

            loss.backward()
            optimizer.step()

            if step % logEvery == 0 and visualisation is not None:

                # At step 0, tmpLoss holds a single batch loss. Afterwards it
                # accumulates the logEvery losses since the last report.
                divisor = 1 if step == 0 else logEvery
                lossTrain.append(tmpLoss / divisor)
                iterList.append(step)
                tokenTrain = visualisation.publishLinePlot([('lossTrain', lossTrain)], iterList,
                                                           name="Loss train",
                                                           window_token=tokenTrain,
                                                           env="main")

                lossEval = _validationLoss(resnet18, validationLoader,
                                           lossMode, device, maxValBatches)
                lossVal.append(lossEval)
                tokenVal = visualisation.publishLinePlot([('lossValidation', lossVal)], iterList,
                                                         name="Loss validation",
                                                         window_token=tokenVal,
                                                         env="main")

                print("[%5d ; %5d ] Loss train : %f ; Loss validation %f"
                      % (epoch, batchIdx, tmpLoss / divisor, lossEval))
                tmpLoss = 0

            step += 1

    return resnet18.module


def _validationLoss(model, loader, lossMode, device, maxBatches):
    """Return the mean of ``lossMode.getLoss`` over at most *maxBatches* batches.

    Runs the model on the batches drawn from *loader* in eval mode with
    autograd disabled. The model is always put back in train mode afterwards.
    Returns NaN when *loader* yields no batch, instead of dividing by zero.
    """
    model.eval()
    total = 0.0
    nBatches = 0
    try:
        with torch.no_grad():
            for inputs, labels in loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                total += lossMode.getLoss(model(inputs), labels).item()
                nBatches += 1
                if nBatches == maxBatches:
                    break
    finally:
        model.train()
    return total / nBatches if nBatches else float("nan")