def train()

in models/trainer/progressive_gan_trainer.py


    def train(self):
        r"""
        Launch the training. Training stops early if divergent behavior is
        detected.

        Returns:

            - True if the training completed
            - False if the training was interrupted because divergent
              behavior was detected
        """

        # Total number of resolution scales in the progressive schedule
        n_scales = len(self.modelConfig.depthScales)

        # Save the base training configuration once, before any scale starts
        if self.checkPointDir is not None:
            pathBaseConfig = os.path.join(self.checkPointDir, self.modelLabel
                                          + "_train_config.json")
            self.saveBaseConfig(pathBaseConfig)

        # Train each scale in turn, growing the model between scales
        for scale in range(self.startScale, n_scales):

            # Switch the dataset to the resolution of this scale
            self.updateDatasetForScale(scale)

            # Make sure a loss profile entry exists for this scale
            while scale >= len(self.lossProfile):
                self.lossProfile.append(
                    {"scale": scale, "iter": []})

            dbLoader = self.getDBLoader(scale)
            sizeDB = len(dbLoader)  # batches per epoch (one batch = one iteration)

            # When resuming from a checkpoint, skip the iterations already done
            shiftIter = 0
            if self.startIter > 0:
                shiftIter = self.startIter
                self.startIter = 0

            # Advance past the alpha jumps scheduled before the resume point
            shiftAlpha = 0
            while shiftAlpha < len(self.modelConfig.iterAlphaJump[scale]) and \
                    self.modelConfig.iterAlphaJump[scale][shiftAlpha] < shiftIter:
                shiftAlpha += 1

            while shiftIter < self.modelConfig.maxIterAtScale[scale]:

                self.indexJumpAlpha = shiftAlpha
                status = self.trainOnEpoch(dbLoader, scale,
                                           shiftIter=shiftIter,
                                           maxIter=self.modelConfig.maxIterAtScale[scale])

                # trainOnEpoch returns False when divergent behavior is detected
                if not status:
                    return False

                # One epoch consumes sizeDB iterations; advance the alpha index
                shiftIter += sizeDB
                while shiftAlpha < len(self.modelConfig.iterAlphaJump[scale]) and \
                        self.modelConfig.iterAlphaJump[scale][shiftAlpha] < shiftIter:
                    shiftAlpha += 1

            # Save a checkpoint at the end of this scale
            if self.checkPointDir is not None:
                realIter = min(
                    shiftIter, self.modelConfig.maxIterAtScale[scale])
                label = self.modelLabel + ("_s%d_i%d" %
                                           (scale, realIter))
                self.saveCheckpoint(self.checkPointDir,
                                    label, scale, realIter)

            # Last scale reached: nothing left to grow
            if scale == n_scales - 1:
                break

            # Grow the networks with the next scale's layer depth
            self.model.addScale(self.modelConfig.depthScales[scale + 1])

        # Mark the training as fully completed
        self.startScale = n_scales
        self.startIter = self.modelConfig.maxIterAtScale[-1]
        return True
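
For orientation, below is a minimal usage sketch. The class name matches the file above, but the constructor argument names and the sample configuration values are assumptions for illustration, not a guaranteed API.

    # Minimal usage sketch: argument names and config values below are
    # assumptions for illustration, not a guaranteed API.
    from models.trainer.progressive_gan_trainer import ProgressiveGANTrainer

    trainer = ProgressiveGANTrainer(
        pathdb="path/to/dataset",       # hypothetical argument name
        checkPointDir="checkpoints",    # hypothetical argument name
        modelLabel="demo_pgan")         # hypothetical argument name

    # Assumed layout of the config fields read by train(), one entry per scale:
    #   depthScales    = [512, 512, 512, 256]           # feature depth per scale
    #   maxIterAtScale = [48000, 96000, 96000, 96000]   # iteration budget per scale
    #   iterAlphaJump  = [[], [0, 1000, 2000], ...]     # alpha update schedule

    if trainer.train():
        print("Training completed over all scales")
    else:
        print("Training interrupted: divergent behavior detected")

Note that on completion train() sets startScale and startIter past the last scale, so a subsequent call returns True immediately instead of retraining the final scale.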