in models/eval/inspirational_generation.py [0:0]
def gradientDescentOnInput(model,
                           input,
                           featureExtractors,
                           imageTransforms,
                           weights=None,
                           visualizer=None,
                           lambdaD=0.03,
                           nSteps=6000,
                           randomSearch=False,
                           nevergrad=None,
                           lr=1,
                           outPathSave=None):
    r"""
    Performs a similarity search with gradient descent.

    Args:
        model (BaseGAN): trained GAN model to use
        input (tensor): inspiration images for the gradient descent. It should
                        be a [NxCxWxH] tensor with N the number of images, C
                        the number of color channels (typically 3), W the
                        image width and H the image height
        featureExtractors (nn.module): list of networks used to extract
                                       features from an image
        imageTransforms (nn.module): list of preprocessing transforms, one per
                                     feature extractor
        weights (list of float): if not None, weight to give to each feature
                                 extractor in the loss criterion
        visualizer (visualizer): if not None, visualizer to use to plot
                                 intermediate results
        lambdaD (float): weight of the realism loss
        nSteps (int): number of steps to perform
        randomSearch (bool): if true, replace the gradient descent by a random
                             search
        nevergrad (string): must be None or in ['CMA', 'DE', 'PSO',
                            'TwoPointsDE', 'PortfolioDiscreteOnePlusOne',
                            'DiscreteOnePlusOne', 'OnePlusOne']
        lr (float): learning rate of the Adam optimizer on the latent vector
        outPathSave (string): if not None, path to save the intermediate
                              iterations of the gradient descent

    Returns:
        output, optimalVector, optimalLoss

        output (tensor): output images
        optimalVector (tensor): latent vectors corresponding to the output
                                images
        optimalLoss (tensor): loss achieved by each optimal latent vector

    Raises:
        ValueError: if nevergrad is not a supported optimizer name, or if the
                    number of image transforms / weights does not match the
                    number of feature extractors
    """
    if nevergrad not in [None, 'CMA', 'DE', 'PSO',
                         'TwoPointsDE', 'PortfolioDiscreteOnePlusOne',
                         'DiscreteOnePlusOne', 'OnePlusOne']:
        # Fixed typo in the error message ("nevergard" -> "nevergrad")
        raise ValueError("Invalid nevergrad mode " + str(nevergrad))

    # A nevergrad optimizer samples candidates instead of following
    # gradients, so it implies the random-search code path.
    randomSearch = randomSearch or (nevergrad is not None)
    # Fixed typo in the progress message ("setps" -> "steps")
    print("Running for %d steps" % nSteps)

    if visualizer is not None:
        visualizer.publishTensors(input, (128, 128))

    # Detect categories
    varNoise = torch.randn((input.size(0),
                            model.config.noiseVectorDim +
                            model.config.categoryVectorDim),
                           requires_grad=True, device=model.device)

    optimNoise = optim.Adam([varNoise],
                            betas=[0., 0.99], lr=lr)

    noiseOut = model.test(varNoise, getAvG=True, toCPU=False)

    if not isinstance(featureExtractors, list):
        featureExtractors = [featureExtractors]
    if not isinstance(imageTransforms, list):
        imageTransforms = [imageTransforms]

    nExtractors = len(featureExtractors)

    if weights is None:
        weights = [1.0 for i in range(nExtractors)]

    if len(imageTransforms) != nExtractors:
        raise ValueError(
            "The number of image transforms should match the number of \
            feature extractors")
    if len(weights) != nExtractors:
        raise ValueError(
            "The number of weights should match the number of feature\
            extractors")

    # Extract the reference features of the inspiration images once; they are
    # constant targets for the whole descent, hence the .detach().
    featuresIn = []
    for i in range(nExtractors):

        if len(featureExtractors[i]._modules) > 0:
            featureExtractors[i] = nn.DataParallel(
                featureExtractors[i]).train().to(model.device)

        featureExtractors[i].eval()
        imageTransforms[i] = nn.DataParallel(
            imageTransforms[i]).to(model.device)

        featuresIn.append(featureExtractors[i](
            imageTransforms[i](input.to(model.device))).detach())

        if nevergrad is None:
            featureExtractors[i].train()

    # NOTE(review): this silently overrides the lr argument for every later
    # optimizer reset; looks unintentional -- TODO confirm before changing.
    lr = 1

    optimalVector = None
    optimalLoss = None
    # The learning rate is decayed by gradientDecay at each third of the run.
    epochStep = int(nSteps / 3)
    gradientDecay = 0.1

    nImages = input.size(0)
    print(f"Generating {nImages} images")
    if nevergrad is not None:
        # One independent nevergrad optimizer per image to generate.
        optimizers = []
        for i in range(nImages):
            optimizers += [optimizerlib.registry[nevergrad](
                dimension=model.config.noiseVectorDim +
                model.config.categoryVectorDim,
                budget=nSteps)]

    def resetVar(newVal):
        r"""Restart the descent from newVal with the current learning rate."""
        # Fixed: without these nonlocal declarations the assignments below
        # only created function-local names, so the restart from the optimal
        # vector and the optimizer reset never actually happened.
        nonlocal varNoise
        nonlocal optimNoise
        newVal.requires_grad = True
        print("Updating the optimizer with learning rate : %f" % lr)
        varNoise = newVal
        optimNoise = optim.Adam([varNoise],
                                betas=[0., 0.99], lr=lr)

    # String's format for loss output
    formatCommand = ' '.join(['{:>4}' for x in range(nImages)])
    for step in range(nSteps):

        optimNoise.zero_grad()
        model.netG.zero_grad()
        model.netD.zero_grad()

        if randomSearch:
            varNoise = torch.randn((nImages,
                                    model.config.noiseVectorDim +
                                    model.config.categoryVectorDim),
                                   device=model.device)
            if nevergrad:
                # Ask each per-image optimizer for its next candidate vector.
                inps = []
                for i in range(nImages):
                    inps += [optimizers[i].ask()]
                npinps = np.array(inps)

                varNoise = torch.tensor(
                    npinps, dtype=torch.float32, device=model.device)
                varNoise.requires_grad = True
                varNoise.to(model.device)

        noiseOut = model.netG(varNoise)
        sumLoss = torch.zeros(nImages, device=model.device)

        # Regularization: keep each latent vector close to the unit sphere.
        loss = (((varNoise**2).mean(dim=1) - 1)**2)
        sumLoss += loss.view(nImages)
        if not randomSearch:
            # Fixed: this backward() was previously unconditional, which
            # raises a RuntimeError on the pure random-search path where
            # varNoise is created without requires_grad. Now guarded like
            # every other backward call in this loop.
            loss.sum(dim=0).backward(retain_graph=True)

        # Feature-matching loss against the inspiration images.
        for i in range(nExtractors):
            featureOut = featureExtractors[i](imageTransforms[i](noiseOut))
            diff = ((featuresIn[i] - featureOut)**2)
            loss = weights[i] * diff.mean(dim=1)
            sumLoss += loss

            if not randomSearch:
                # The graph must be retained until the last backward pass.
                retainGraph = (lambdaD > 0) or (i != nExtractors - 1)
                loss.sum(dim=0).backward(retain_graph=retainGraph)

        # Realism loss: push the generated image toward the discriminator's
        # "real" decision.
        if lambdaD > 0:
            loss = -lambdaD * model.netD(noiseOut)[:, 0]
            sumLoss += loss

            if not randomSearch:
                loss.sum(dim=0).backward()

        if nevergrad:
            # Report each candidate's loss back to its optimizer.
            for i in range(nImages):
                optimizers[i].tell(inps[i], float(sumLoss[i]))
        elif not randomSearch:
            optimNoise.step()

        # Track the best latent vector seen so far for each image.
        if optimalLoss is None:
            optimalVector = deepcopy(varNoise)
            optimalLoss = sumLoss

        else:
            optimalVector = torch.where(sumLoss.view(-1, 1) < optimalLoss.view(-1, 1),
                                        varNoise, optimalVector).detach()
            optimalLoss = torch.where(sumLoss < optimalLoss,
                                      sumLoss, optimalLoss).detach()

        if step % 100 == 0:
            if visualizer is not None:
                visualizer.publishTensors(noiseOut.cpu(), (128, 128))

                if outPathSave is not None:
                    index_str = str(int(step/100))
                    outPath = os.path.join(outPathSave, index_str + ".jpg")
                    visualizer.saveTensor(
                        noiseOut.cpu(),
                        (noiseOut.size(2), noiseOut.size(3)),
                        outPath)

            print(str(step) + " : " + formatCommand.format(
                *["{:10.6f}".format(sumLoss[i].item())
                  for i in range(nImages)]))

        # Learning-rate decay: restart from the best vector with a smaller lr.
        if step % epochStep == (epochStep - 1):
            lr *= gradientDecay
            resetVar(optimalVector)

    output = model.test(optimalVector, getAvG=True, toCPU=True).detach()

    if visualizer is not None:
        visualizer.publishTensors(
            output.cpu(), (output.size(2), output.size(3)))

    print("optimal losses : " + formatCommand.format(
        *["{:10.6f}".format(optimalLoss[i].item())
          for i in range(nImages)]))

    return output, optimalVector, optimalLoss