def gradientDescentOnInput()

in models/eval/inspirational_generation.py [0:0]


def gradientDescentOnInput(model,
                           input,
                           featureExtractors,
                           imageTransforms,
                           weights=None,
                           visualizer=None,
                           lambdaD=0.03,
                           nSteps=6000,
                           randomSearch=False,
                           nevergrad=None,
                           lr=1,
                           outPathSave=None):
    r"""
    Performs a similarity search with gradient descent.

    Looks for latent vectors whose generated images are close, in the
    feature spaces of ``featureExtractors``, to the inspiration images
    ``input``. The latent vectors are optimized with Adam (default), with a
    pure random search (``randomSearch=True``) or with a nevergrad optimizer.

    Args:

        model (BaseGAN): trained GAN model to use
        input (tensor): inspiration images for the gradient descent. It should
                        be a [NxCxWxH] tensor with N the number of image, C the
                        number of color channels (typically 3), W the image
                        width and H the image height
        featureExtractors (nn.module): network, or list of networks, used to
                                       extract features from an image. Each
                                       one must return a [N x F] tensor
        imageTransforms (nn.module): transform, or list of transforms (one
                                     per feature extractor), applied to the
                                     images before feature extraction
        weights (list of float): if not None, weight to give to each feature
                                 extractor in the loss criterion
        visualizer (visualizer): if not None, visualizer to use to plot
                                 intermediate results
        lambdaD (float): weight of the realism loss
        nSteps (int): number of steps to perform (at least 1)
        randomSearch (bool): if true, replace the gradient descent by a random
                             search
        nevergrad (string): must be in None or in ['CMA', 'DE', 'PSO',
                            'TwoPointsDE', 'PortfolioDiscreteOnePlusOne',
                            'DiscreteOnePlusOne', 'OnePlusOne']
        lr (float): initial learning rate of the Adam optimizer. It is
                    multiplied by 0.1 every max(1, nSteps // 3) steps, and
                    the descent then restarts from the best vectors found
        outPathSave (string): if not None, path to save the intermediate
                              iterations of the gradient descent (only used
                              when a visualizer is given)
    Returns

        output, optimalVector, optimalLoss

        output (tensor): output images
        optimalVector (tensor): latent vectors corresponding to the output
                                images
        optimalLoss (tensor): [N] loss reached by each vector of
                              optimalVector

    Raises:

        ValueError: invalid nevergrad mode, nSteps < 1, or lengths of
                    featureExtractors / imageTransforms / weights that do
                    not match
    """

    if nevergrad not in [None, 'CMA', 'DE', 'PSO',
                         'TwoPointsDE', 'PortfolioDiscreteOnePlusOne',
                         'DiscreteOnePlusOne', 'OnePlusOne']:
        raise ValueError("Invalid nevergrad mode " + str(nevergrad))
    if nSteps < 1:
        # Without any step there is no latent vector to return
        raise ValueError("nSteps should be at least 1, got " + str(nSteps))
    # nevergrad optimizers only consume loss values, never gradients
    randomSearch = randomSearch or (nevergrad is not None)
    print("Running for %d steps" % nSteps)

    if visualizer is not None:
        visualizer.publishTensors(input, (128, 128))

    nImages = input.size(0)
    latentDim = model.config.noiseVectorDim + model.config.categoryVectorDim

    # Initial latent vectors (noise + category parts), updated in place by
    # Adam in gradient-descent mode
    varNoise = torch.randn((nImages, latentDim),
                           requires_grad=True, device=model.device)
    optimNoise = optim.Adam([varNoise], betas=[0., 0.99], lr=lr)

    # Work on copies: their elements are wrapped in nn.DataParallel below and
    # the caller's lists must not be modified.
    if isinstance(featureExtractors, list):
        featureExtractors = list(featureExtractors)
    else:
        featureExtractors = [featureExtractors]
    if isinstance(imageTransforms, list):
        imageTransforms = list(imageTransforms)
    else:
        imageTransforms = [imageTransforms]

    nExtractors = len(featureExtractors)

    if weights is None:
        weights = [1.0] * nExtractors

    if len(imageTransforms) != nExtractors:
        raise ValueError("The number of image transforms should match the "
                         "number of feature extractors")
    if len(weights) != nExtractors:
        raise ValueError("The number of weights should match the number of "
                         "feature extractors")

    # Target features of the inspiration images (constants, hence detached)
    featuresIn = []
    for i in range(nExtractors):

        if len(featureExtractors[i]._modules) > 0:
            featureExtractors[i] = nn.DataParallel(
                featureExtractors[i]).train().to(model.device)

        featureExtractors[i].eval()
        imageTransforms[i] = nn.DataParallel(
            imageTransforms[i]).to(model.device)

        featuresIn.append(featureExtractors[i](
            imageTransforms[i](input.to(model.device))).detach())

        if nevergrad is None:
            featureExtractors[i].train()

    optimalVector = None
    optimalLoss = None

    # lr is multiplied by gradientDecay every epochStep steps; max() avoids a
    # modulo by zero when nSteps < 3.
    epochStep = max(1, nSteps // 3)
    gradientDecay = 0.1

    print(f"Generating {nImages} images")
    if nevergrad is not None:
        # One independent derivative-free optimizer per image
        optimizers = [optimizerlib.registry[nevergrad](dimension=latentDim,
                                                       budget=nSteps)
                      for _ in range(nImages)]

    def resetVar(newVal):
        # Restart the descent from newVal with a new Adam optimizer using the
        # current lr. nonlocal is required: otherwise the assignments below
        # would only bind locals of resetVar and the reset would be lost.
        nonlocal varNoise, optimNoise
        print("Updating the optimizer with learning rate : %f" % lr)
        # Clone so that later in-place optimizer steps cannot alter newVal
        varNoise = newVal.detach().clone().requires_grad_(True)
        optimNoise = optim.Adam([varNoise], betas=[0., 0.99], lr=lr)

    # String's format for loss output
    formatCommand = ' '.join(['{:>4}'] * nImages)
    for step in range(nSteps):

        optimNoise.zero_grad()
        model.netG.zero_grad()
        model.netD.zero_grad()

        if nevergrad is not None:
            inps = [opt.ask() for opt in optimizers]
            varNoise = torch.tensor(np.array(inps), dtype=torch.float32,
                                    device=model.device)
        elif randomSearch:
            varNoise = torch.randn((nImages, latentDim), device=model.device)

        noiseOut = model.netG(varNoise)

        # Penalize vectors whose mean squared coordinate drifts away from 1.
        # sumLoss is a detached per-image accumulator used for tracking and
        # reporting only; gradients come from the individual losses.
        normLoss = ((varNoise ** 2).mean(dim=1) - 1) ** 2
        sumLoss = normLoss.detach().clone()
        if not randomSearch:
            normLoss.sum(dim=0).backward(retain_graph=True)

        for i in range(nExtractors):
            featureOut = featureExtractors[i](imageTransforms[i](noiseOut))
            loss = weights[i] * ((featuresIn[i] - featureOut) ** 2).mean(dim=1)
            sumLoss += loss.detach()

            if not randomSearch:
                # Keep noiseOut's graph for the next extractor / netD term
                retainGraph = (lambdaD > 0) or (i != nExtractors - 1)
                loss.sum(dim=0).backward(retain_graph=retainGraph)

        if lambdaD > 0:
            # Realism term: lower loss for higher discriminator scores
            loss = -lambdaD * model.netD(noiseOut)[:, 0]
            sumLoss += loss.detach()

            if not randomSearch:
                loss.sum(dim=0).backward()

        # Record the best vectors BEFORE the optimizer step: sumLoss belongs
        # to the current varNoise, which optimNoise.step() updates in place.
        candidate = varNoise.detach()
        if optimalLoss is None:
            optimalVector = candidate.clone()
            optimalLoss = sumLoss
        else:
            improved = sumLoss < optimalLoss
            optimalVector = torch.where(improved.view(-1, 1),
                                        candidate, optimalVector)
            optimalLoss = torch.where(improved, sumLoss, optimalLoss)

        if nevergrad is not None:
            for opt, inp, value in zip(optimizers, inps, sumLoss.tolist()):
                opt.tell(inp, value)
        elif not randomSearch:
            optimNoise.step()

        if step % 100 == 0:
            if visualizer is not None:
                visualizer.publishTensors(noiseOut.cpu(), (128, 128))

                if outPathSave is not None:
                    outPath = os.path.join(outPathSave,
                                           str(step // 100) + ".jpg")
                    visualizer.saveTensor(
                        noiseOut.cpu(),
                        (noiseOut.size(2), noiseOut.size(3)),
                        outPath)

            print(str(step) + " : " + formatCommand.format(
                *["{:10.6f}".format(value) for value in sumLoss.tolist()]))

        if step % epochStep == (epochStep - 1):
            lr *= gradientDecay
            resetVar(optimalVector)

    output = model.test(optimalVector, getAvG=True, toCPU=True).detach()

    if visualizer is not None:
        visualizer.publishTensors(
            output.cpu(), (output.size(2), output.size(3)))

    print("optimal losses : " + formatCommand.format(
        *["{:10.6f}".format(value) for value in optimalLoss.tolist()]))
    return output, optimalVector, optimalLoss