in models/trainer/progressive_gan_trainer.py [0:0]
def train(self):
    r"""
    Launch the training. This one will stop if a divergent behavior is
    detected.

    Training runs scale by scale, from ``self.startScale`` up to the last
    entry of ``self.modelConfig.depthScales``. In the first scale trained,
    iterations resume from ``self.startIter`` when it is positive (it is
    then reset to 0 so later scales start from scratch). For each scale,
    ``trainOnEpoch`` is called repeatedly until
    ``modelConfig.maxIterAtScale[scale]`` iterations are reached. After that,
    a checkpoint is saved (when ``self.checkPointDir`` is set). Unless this
    is the last scale, the next depth from ``depthScales`` is added to the
    model.

    Returns:
        - True if the training completed
        - False if the training was interrupted due to a divergent behavior
          (``trainOnEpoch`` returned a falsy status)

    Raises:
        ValueError: if the data loader of a scale that still has iterations
            to run has length 0. The iteration counter advances by the
            loader length after each epoch, so it could never reach the
            target.
    """
    def skipPassedAlphaJumps(alphaJumps, index, iteration):
        # Advance `index` past every alpha-jump point strictly before
        # `iteration`; a jump located exactly at `iteration` is not skipped.
        while index < len(alphaJumps) and alphaJumps[index] < iteration:
            index += 1
        return index

    n_scales = len(self.modelConfig.depthScales)

    if self.checkPointDir is not None:
        pathBaseConfig = os.path.join(self.checkPointDir,
                                      self.modelLabel + "_train_config.json")
        self.saveBaseConfig(pathBaseConfig)

    for scale in range(self.startScale, n_scales):
        self.updateDatasetForScale(scale)

        # lossProfile is indexed by scale: pad it so lossProfile[scale]
        # exists, labelling each appended entry with its own index.
        while scale >= len(self.lossProfile):
            self.lossProfile.append(
                {"scale": len(self.lossProfile), "iter": []})

        dbLoader = self.getDBLoader(scale)
        sizeDB = len(dbLoader)
        maxIter = self.modelConfig.maxIterAtScale[scale]
        alphaJumps = self.modelConfig.iterAlphaJump[scale]

        # A pending resume iteration applies only to the first scale
        # trained; consume it so that later scales start from 0.
        shiftIter = 0
        if self.startIter > 0:
            shiftIter = self.startIter
            self.startIter = 0

        # Without this guard an empty loader makes the loop below spin
        # forever, since shiftIter would never grow.
        if sizeDB == 0 and shiftIter < maxIter:
            raise ValueError(
                "Empty data loader at scale %d: cannot reach the %d "
                "iterations required at this scale" % (scale, maxIter))

        shiftAlpha = skipPassedAlphaJumps(alphaJumps, 0, shiftIter)

        while shiftIter < maxIter:
            self.indexJumpAlpha = shiftAlpha
            status = self.trainOnEpoch(dbLoader, scale,
                                       shiftIter=shiftIter,
                                       maxIter=maxIter)
            if not status:
                return False

            # Each epoch is counted as len(dbLoader) iterations.
            shiftIter += sizeDB
            shiftAlpha = skipPassedAlphaJumps(alphaJumps, shiftAlpha,
                                              shiftIter)

        # Save a checkpoint
        if self.checkPointDir is not None:
            realIter = min(shiftIter, maxIter)
            label = self.modelLabel + ("_s%d_i%d" % (scale, realIter))
            self.saveCheckpoint(self.checkPointDir, label, scale, realIter)

        if scale == n_scales - 1:
            break

        self.model.addScale(self.modelConfig.depthScales[scale + 1])

    # Move the resume point past the last scale so a subsequent call
    # finds nothing left to train.
    self.startScale = n_scales
    self.startIter = self.modelConfig.maxIterAtScale[-1]
    return True