in models/trainer/gan_trainer.py [0:0]
def __init__(self,
             pathdb,
             useGPU=True,
             visualisation=None,
             lossIterEvaluation=200,
             saveIter=5000,
             checkPointDir=None,
             modelLabel="GAN",
             config=None,
             pathAttribDict=None,
             selectedAttributes=None,
             imagefolderDataset=False,
             ignoreAttribs=False,
             pathPartition=None,
             partitionValue=None):
    r"""
    Set up the trainer state: dataset location, training configuration,
    model, visualisation handles and checkpoint paths.

    Args:
        - pathdb (string): path to the directory containing the image
                           dataset.
        - useGPU (bool): set to True if you want to use the available GPUs
                         for the training procedure. When False,
                         self.numWorkers is forced to 1.
        - visualisation (module): if not None, a visualisation module to
                                  follow the evolution of the training
        - lossIterEvaluation (int): size of the interval on which the
                                    model's loss will be evaluated
        - saveIter (int): frequency at which a checkpoint should be saved
        - checkPointDir (string): if not None, directory where the
                                  checkpoints should be saved. When None,
                                  both self.pathLossLog and
                                  self.pathRefVector are left to None.
        - modelLabel (string): name of the model; used as the prefix of
                               the checkpoint file names
        - config (dictionary): training configuration, forwarded to
                               self.readTrainConfig. An empty dictionary
                               is used when None.
        - pathAttribDict (string): path to the attribute dictionary giving
                                   the labels of the dataset
        - selectedAttributes (list): if not None, consider only the listed
                                     attributes for labelling
        - imagefolderDataset (bool): set to true if the data are stored in
                                     the fashion of a
                                     torchvision.datasets.ImageFolder
                                     object
        - ignoreAttribs (bool): set to True if the input attrib dict should
                                only be used as a filter on image's names
        - pathPartition (string): if only a subset of the original dataset
                                  should be used
        - partitionValue (string): partition value
    """

    # Parameters
    # Training dataset
    self.path_db = pathdb
    self.pathPartition = pathPartition
    self.partitionValue = partitionValue

    if config is None:
        config = {}

    # Load the training configuration
    # NOTE(review): readTrainConfig is expected to create self.modelConfig
    # (it is written to just below) - confirm against its definition.
    self.readTrainConfig(config)

    # Initialize the model
    self.useGPU = useGPU

    if not self.useGPU:
        self.numWorkers = 1

    self.pathAttribDict = pathAttribDict
    self.selectedAttributes = selectedAttributes
    self.imagefolderDataset = imagefolderDataset
    self.modelConfig.attribKeysOrder = None

    # Labels are only collected when the attributes are not ignored and a
    # label source (attribute dictionary or image-folder layout) exists.
    if (not ignoreAttribs) and \
            (self.pathAttribDict is not None or self.imagefolderDataset):
        self.modelConfig.attribKeysOrder = self.getDataset(
            0, size=10).getKeyOrders()

        print("AC-GAN classes : ")
        print(self.modelConfig.attribKeysOrder)
        print("")

    # Intern state
    self.runningLoss = {}
    self.startScale = 0
    self.startIter = 0
    self.lossProfile = []

    self.initModel()

    print("%d images detected" % int(len(self.getDataset(0, size=10))))

    # Visualization ?
    self.visualisation = visualisation
    self.tokenWindowFake = None
    self.tokenWindowFakeSmooth = None
    self.tokenWindowReal = None
    self.tokenWindowLosses = None
    self.refVectorPath = None

    # Fixed batch of noise inputs (and their labels) built once here
    self.nDataVisualization = 16
    self.refVectorVisualization, self.refVectorLabels = \
        self.model.buildNoiseData(self.nDataVisualization)

    # Checkpoints ?
    self.checkPointDir = checkPointDir
    self.modelLabel = modelLabel
    self.saveIter = saveIter

    # Both paths always exist as attributes so that reading them without a
    # checkpoint directory yields None instead of raising AttributeError.
    self.pathLossLog = None
    self.pathRefVector = None

    if self.checkPointDir is not None:
        self.pathLossLog = os.path.abspath(os.path.join(self.checkPointDir,
                                                        self.modelLabel
                                                        + '_losses.pkl'))
        self.pathRefVector = os.path.abspath(os.path.join(self.checkPointDir,
                                                          self.modelLabel
                                                          + '_refVectors.pt'))

    # Loss printing
    self.lossIterEvaluation = lossIterEvaluation