in models/trainer/gan_trainer.py [0:0]
def loadSavedTraining(self,
                      pathModel,
                      pathTrainConfig,
                      pathTmpConfig,
                      loadGOnly=False,
                      loadDOnly=False,
                      finetune=False):
    r"""
    Load a given checkpoint.

    Restores, in this order: the temporary training state (start scale,
    start iteration, running losses), the loss profile, the training
    configuration, the model weights and the reference vectors used for
    visualization.

    Args:

        - pathModel (string): path to the file containing the model
                              structure (.pt)
        - pathTrainConfig (string): path to the reference configuration
                                    file of the training. WARNING: this
                                    file must be compatible with the one
                                    pointed by pathModel. Not read when
                                    finetune is True.
        - pathTmpConfig (string): path to the temporary file describing the
                                  state of the training when the checkpoint
                                  was saved. WARNING: this file must be
                                  compatible with the one pointed by
                                  pathModel. If None, no temporary state is
                                  restored and self.startScale must already
                                  be set, since it is read to build the
                                  empty loss profile.
        - loadGOnly (bool): if True, loadD=False is forwarded to
                            self.model.load
        - loadDOnly (bool): if True, loadG=False is forwarded to
                            self.model.load
        - finetune (bool): if True, the training configuration file is not
                           read and finetuning=True is forwarded to
                           self.model.load
    """

    # Load the temp configuration
    tmpPathLossLog = None
    tmpConfig = {}

    if pathTmpConfig is not None:
        # Context managers guarantee the handles are closed even when the
        # parsing fails.
        with open(pathTmpConfig, 'rb') as tmpFile:
            tmpConfig = json.load(tmpFile)
        self.startScale = tmpConfig["scale"]
        self.startIter = tmpConfig["iter"]
        self.runningLoss = tmpConfig.get("runningLoss", {})

        tmpPathLossLog = tmpConfig.get("lossLog", None)

    # Restore the loss profile, or start from an empty one
    if tmpPathLossLog is None:
        self.lossProfile = [
            {"iter": [], "scale": self.startScale}]
    elif not os.path.isfile(tmpPathLossLog):
        print("WARNING : couldn't find the loss logs at " +
              tmpPathLossLog + " resetting the losses")
        self.lossProfile = [
            {"iter": [], "scale": self.startScale}]
    else:
        # NOTE: unpickling can execute arbitrary code; only load loss logs
        # coming from a trusted checkpoint directory.
        with open(tmpPathLossLog, 'rb') as lossFile:
            self.lossProfile = pkl.load(lossFile)

        # Drop the scales logged after the checkpointed scale
        self.lossProfile = self.lossProfile[:(self.startScale + 1)]

        lastProfile = self.lossProfile[-1]
        loggedIters = lastProfile["iter"]

        # The log may run past the checkpointed iteration: cut every
        # list-valued entry (including "iter" itself) at the first logged
        # iteration greater than startIter. The emptiness check avoids an
        # IndexError when nothing was logged yet for the last scale.
        if loggedIters and loggedIters[-1] > self.startIter:
            indexStop = next(index
                             for index, iterValue in enumerate(loggedIters)
                             if iterValue > self.startIter)

            for key, value in lastProfile.items():
                if isinstance(value, list):
                    lastProfile[key] = value[:indexStop]

    # Read the training configuration
    if not finetune:
        with open(pathTrainConfig, 'rb') as trainFile:
            trainConfig = json.load(trainFile)
        self.readTrainConfig(trainConfig)

    # Re-initialize the model
    self.initModel()
    self.model.load(pathModel,
                    loadG=not loadDOnly,
                    loadD=not loadGOnly,
                    finetuning=finetune)

    # Build or retrieve the reference vectors
    self.refVectorPath = tmpConfig.get("refVectors", None)
    if self.refVectorPath is None:
        self.refVectorVisualization, self.refVectorLabels = \
            self.model.buildNoiseData(self.nDataVisualization)
    elif not os.path.isfile(self.refVectorPath):
        print("WARNING : no file found at " + self.refVectorPath
              + " building new reference vectors")
        self.refVectorVisualization, self.refVectorLabels = \
            self.model.buildNoiseData(self.nDataVisualization)
    else:
        # NOTE(review): this branch only restores refVectorVisualization;
        # refVectorLabels is left untouched - confirm this is intended.
        with open(self.refVectorPath, 'rb') as refFile:
            self.refVectorVisualization = torch.load(refFile)