def loadSavedTraining()

in models/trainer/gan_trainer.py [0:0]


    def loadSavedTraining(self,
                          pathModel,
                          pathTrainConfig,
                          pathTmpConfig,
                          loadGOnly=False,
                          loadDOnly=False,
                          finetune=False):
        r"""
        Load a given checkpoint.

        Restores, in order: the training state (start scale / iteration,
        running loss and loss history), the training configuration, the model
        weights and the reference vectors used for visualization.

        Args:

            - pathModel (string): path to the file containing the model
                                 structure (.pt)
            - pathTrainConfig (string): path to the reference configuration
                                        file of the training. WARNING: this
                                        file must be compatible with the one
                                        pointed by pathModel. Not read when
                                        finetune is True.
            - pathTmpConfig (string): path to the temporary file describing the
                                      state of the training when the checkpoint
                                      was saved. WARNING: this file must be
                                      compatible with the one pointed by
                                      pathModel. If None, self.startScale and
                                      self.startIter are left untouched, the
                                      loss history is reset and new reference
                                      vectors are built.
            - loadGOnly (bool): if True, self.model.load is called with
                                loadD=False
            - loadDOnly (bool): if True, self.model.load is called with
                                loadG=False
            - finetune (bool): if True, the training configuration is not
                               read from pathTrainConfig and the flag is
                               forwarded to self.model.load as `finetuning`

        NOTE: the loss log and the reference vectors are read with pickle /
        torch.load, which can execute arbitrary code while deserializing.
        Only load checkpoints coming from a trusted source.
        """

        # Load the temp configuration
        tmpPathLossLog = None
        tmpConfig = {}

        if pathTmpConfig is not None:
            with open(pathTmpConfig, 'rb') as tmpConfigFile:
                tmpConfig = json.load(tmpConfigFile)
            self.startScale = tmpConfig["scale"]
            self.startIter = tmpConfig["iter"]
            self.runningLoss = tmpConfig.get("runningLoss", {})

            tmpPathLossLog = tmpConfig.get("lossLog", None)

        if tmpPathLossLog is None:
            self.lossProfile = [
                {"iter": [], "scale": self.startScale}]
        elif not os.path.isfile(tmpPathLossLog):
            print("WARNING : couldn't find the loss logs at " +
                  tmpPathLossLog + " resetting the losses")
            self.lossProfile = [
                {"iter": [], "scale": self.startScale}]
        else:
            with open(tmpPathLossLog, 'rb') as lossLogFile:
                self.lossProfile = pkl.load(lossLogFile)

            # Drop every scale logged after the one we restart from
            self.lossProfile = self.lossProfile[:(self.startScale + 1)]

            lastProfile = self.lossProfile[-1]
            loggedIters = lastProfile["iter"]

            # The log may run past the checkpoint: cut everything recorded
            # after self.startIter. An empty iteration list (nothing logged
            # yet at this scale) needs no truncation.
            if len(loggedIters) > 0 and loggedIters[-1] > self.startIter:
                indexStop = next(index
                                 for index, loggedIter in enumerate(loggedIters)
                                 if loggedIter > self.startIter)
                lastProfile["iter"] = loggedIters[:indexStop]

                # Keep every list-valued entry aligned with "iter". Only
                # existing keys are reassigned, so iterating is safe.
                for item in lastProfile:
                    if isinstance(lastProfile[item], list):
                        lastProfile[item] = lastProfile[item][:indexStop]

        # Read the training configuration
        if not finetune:
            with open(pathTrainConfig, 'rb') as trainConfigFile:
                trainConfig = json.load(trainConfigFile)
            self.readTrainConfig(trainConfig)

        # Re-initialize the model
        self.initModel()
        self.model.load(pathModel,
                        loadG=not loadDOnly,
                        loadD=not loadGOnly,
                        finetuning=finetune)

        # Build or retrieve the reference vectors
        self.refVectorPath = tmpConfig.get("refVectors", None)
        if self.refVectorPath is None:
            self.refVectorVisualization, self.refVectorLabels = \
                self.model.buildNoiseData(self.nDataVisualization)
        elif not os.path.isfile(self.refVectorPath):
            print("WARNING : no file found at " + self.refVectorPath
                  + " building new reference vectors")
            self.refVectorVisualization, self.refVectorLabels = \
                self.model.buildNoiseData(self.nDataVisualization)
        else:
            # NOTE(review): only refVectorVisualization is restored here;
            # self.refVectorLabels is left untouched - confirm this is
            # intended.
            with open(self.refVectorPath, 'rb') as refVectorFile:
                self.refVectorVisualization = torch.load(refVectorFile)