in models/base_GAN.py [0:0]
def optimizeParameters(self, input_batch, inputLabels=None):
    r"""
    Run one training step: update the discriminator D, then the generator
    G, then the moving-average copy of G.

    Args:
        input_batch (torch.tensor): input batch of real data
        inputLabels (torch.tensor): labels of the real data; required when
            ``self.config.attribKeysOrder`` is not None

    Returns:
        dict: every loss term recorded during the step (keys prefixed with
        ``lossD`` / ``lossG``), plus the aggregates ``"lossD"`` and
        ``"lossG"``, each the sum of the corresponding prefixed terms.

    Raises:
        ValueError: if the configuration expects labels
            (``attribKeysOrder`` is set) but ``inputLabels`` is None.
    """
    allLosses = {}

    # Retrieve the input data
    self.real_input = input_batch.to(self.device)
    self.realLabels = None
    if self.config.attribKeysOrder is not None:
        if inputLabels is None:
            raise ValueError("inputLabels must be provided when "
                             "config.attribKeysOrder is set")
        self.realLabels = inputLabels.to(self.device)

    n_samples = self.real_input.size()[0]

    # ---- Update the discriminator ----
    self.optimizerD.zero_grad()

    # #1 Real data
    predRealD = self.netD(self.real_input, False)

    # Classification criterion (backward=True presumably makes the helper
    # backpropagate its own loss term)
    allLosses["lossD_classif"] = \
        self.classificationPenalty(predRealD,
                                   self.realLabels,
                                   self.config.weightConditionD,
                                   backward=True)

    lossD = self.lossCriterion.getCriterion(predRealD, True)
    allLosses["lossD_real"] = lossD.item()

    # #2 Fake data; detached so this D step does not backprop into G.
    # The second value returned by buildNoiseData is not needed here.
    inputLatent, _ = self.buildNoiseData(n_samples)
    predFakeG = self.netG(inputLatent).detach()
    predFakeD = self.netD(predFakeG, False)

    lossDFake = self.lossCriterion.getCriterion(predFakeD, False)
    allLosses["lossD_fake"] = lossDFake.item()
    lossD += lossDFake

    # #3 WGANGP gradient loss
    if self.config.lambdaGP > 0:
        allLosses["lossD_Grad"] = WGANGPGradientPenalty(self.real_input,
                                                        predFakeG,
                                                        self.netD,
                                                        self.config.lambdaGP,
                                                        backward=True)

    # #4 Epsilon loss: weighted sum of the squared first-column outputs of
    # D on real data
    if self.config.epsilonD > 0:
        lossEpsilon = (predRealD[:, 0] ** 2).sum() * self.config.epsilonD
        lossD += lossEpsilon
        allLosses["lossD_Epsilon"] = lossEpsilon.item()

    # #5 Logistic gradient loss
    if self.config.logisticGradReal > 0:
        allLosses["lossD_logistic"] = \
            logisticGradientPenalty(self.real_input, self.netD,
                                    self.config.logisticGradReal,
                                    backward=True)

    # No retain_graph: nothing below backpropagates through these D graphs,
    # so their buffers can be freed before the generator step.
    lossD.backward()
    finiteCheck(self.getOriginalD().parameters())
    self.optimizerD.step()

    # Logs: total D loss is the sum of every recorded lossD* term
    allLosses["lossD"] = sum(val for key, val in allLosses.items()
                             if key.startswith("lossD"))

    # ---- Update the generator ----
    self.optimizerG.zero_grad()
    self.optimizerD.zero_grad()

    # #1 Image generation
    inputNoise, targetCatNoise = self.buildNoiseData(n_samples)
    predFakeG = self.netG(inputNoise)

    # #2 Status evaluation; with True, netD also returns a second output
    # (features used by the GDPP loss below)
    predFakeD, phiGFake = self.netD(predFakeG, True)

    # #2 Classification criterion
    allLosses["lossG_classif"] = \
        self.classificationPenalty(predFakeD,
                                   targetCatNoise,
                                   self.config.weightConditionG,
                                   backward=True)

    # #3 GAN criterion: score the generated samples against the "real" target
    lossGFake = self.lossCriterion.getCriterion(predFakeD, True)
    allLosses["lossG_fake"] = lossGFake.item()
    # retain_graph: GDPPLoss below presumably backpropagates through
    # phiGFake, which belongs to this same graph
    lossGFake.backward(retain_graph=True)

    if self.config.GDPP:
        # Call the module (not .forward) so registered hooks still run
        _, phiDReal = self.netD(self.real_input, True)
        allLosses["lossG_GDPP"] = GDPPLoss(phiDReal, phiGFake,
                                           backward=True)

    finiteCheck(self.getOriginalG().parameters())
    self.optimizerG.step()

    # Logs: total G loss is the sum of every recorded lossG* term
    allLosses["lossG"] = sum(val for key, val in allLosses.items()
                             if key.startswith("lossG"))

    # Update the moving average of G: avg = 0.999 * avg + 0.001 * p
    for p, avg_p in zip(self.getOriginalG().parameters(),
                        self.getOriginalAvgG().parameters()):
        avg_p.mul_(0.999).add_(p.data, alpha=0.001)

    return allLosses