in models/base_GAN.py [0:0]
def __init__(self,
             dimLatentVector,
             dimOutput=3,
             useGPU=True,
             baseLearningRate=0.001,
             lossMode='WGANGP',
             attribKeysOrder=None,
             weightConditionD=0.0,
             weightConditionG=0.0,
             logisticGradReal=0.0,
             lambdaGP=0.,
             epsilonD=0.,
             GDPP=False,
             **kwargs):
    r"""
    Set up the configuration, compute device, loss criterion and the
    generator / discriminator networks shared by all GAN models.

    Args:
        dimLatentVector (int): dimension of the noise part of the latent
                               vector fed to the generator
        dimOutput (int): number of channels of the output image
        useGPU (bool): set to True to run on CUDA when it is available;
                       the model falls back to the CPU otherwise
        baseLearningRate (float): target learning rate.
        lossMode (string): loss used by the model. Must be one of the
                           following options
                            * 'MSE' : mean square loss.
                            * 'DCGAN': cross entropy loss
                            * 'WGANGP': https://arxiv.org/pdf/1704.00028.pdf
                            * 'Logistic': https://arxiv.org/pdf/1801.04406.pdf
        attribKeysOrder (dict): if not None, activate AC-GAN. In this case,
                                both the generator and the discriminator
                                are trained on labelled data. A deep copy
                                is stored in the config.
        weightConditionD (float): in AC-GAN, weight of the classification
                                  loss applied to the discriminator
        weightConditionG (float): in AC-GAN, weight of the classification
                                  loss applied to the generator
        logisticGradReal (float): gradient penalty for the logistic loss
        lambdaGP (float): if > 0, weight of the gradient penalty (WGANGP)
        epsilonD (float): if > 0, penalty on |D(X)|**2
        GDPP (bool): if true activate GDPP loss
                     https://arxiv.org/abs/1812.00068
        **kwargs: not used by this constructor.

    Raises:
        ValueError: if lossMode is not one of the supported options.
    """
    validLossModes = ['MSE', 'WGANGP', 'DCGAN', 'Logistic']
    if lossMode not in validLossModes:
        # Build the message from the list itself: a backslash continuation
        # inside the string literal would embed the next source line's
        # indentation into the error text.
        raise ValueError(
            "lossMode should be one of the following : "
            + str(validLossModes))

    # Keep config / trainTmp if they already exist on this instance
    # (presumably set by a subclass before delegating here); otherwise
    # start from empty containers.
    if 'config' not in vars(self):
        self.config = BaseConfig()

    if 'trainTmp' not in vars(self):
        self.trainTmp = BaseConfig()

    # Device selection: first CUDA device when requested and available,
    # CPU otherwise. n_devices records how many devices are visible.
    self.useGPU = useGPU and torch.cuda.is_available()
    if self.useGPU:
        self.device = torch.device("cuda:0")
        self.n_devices = torch.cuda.device_count()
    else:
        self.device = torch.device("cpu")
        self.n_devices = 1

    # Latent vector dimension
    self.config.noiseVectorDim = dimLatentVector

    # Output image dimension
    self.config.dimOutput = dimOutput

    # Actual learning rate
    self.config.learningRate = baseLearningRate

    # AC-GAN ?
    self.config.attribKeysOrder = deepcopy(attribKeysOrder)
    self.config.categoryVectorDim = 0
    self.config.weightConditionG = weightConditionG
    self.config.weightConditionD = weightConditionD
    self.ClassificationCriterion = None
    # NOTE(review): presumably updates categoryVectorDim (and the
    # classification criterion) when AC-GAN is active -- confirm. The
    # value is read just below to size the full latent vector.
    self.initializeClassificationCriterion()

    # GDPP
    self.config.GDPP = GDPP

    # Full latent size = noise part + category part
    self.config.latentVectorDim = self.config.noiseVectorDim \
        + self.config.categoryVectorDim

    # Loss criterion: the class named lossMode in base_loss_criterions,
    # instantiated for the selected device.
    self.config.lossCriterion = lossMode
    self.lossCriterion = getattr(
        base_loss_criterions, lossMode)(self.device)

    # WGAN-GP
    self.config.lambdaGP = lambdaGP

    # Weight on D's output
    self.config.epsilonD = epsilonD

    # Initialize the generator and the discriminator (getNetD / getNetG
    # are presumably implemented by subclasses)
    self.netD = self.getNetD()
    self.netG = self.getNetG()

    # Move the networks to the gpu
    self.updateSolversDevice()

    # Logistic loss
    self.config.logisticGradReal = logisticGradReal