in models/eval/inspirational_generation.py [0:0]
def _loadFeatureExtractors(pathsModel, size):
    """Build the feature extractors and their matching image transforms.

    Args:
        pathsModel (list or None): each entry is either the literal string
            "id" (identity extractor and identity transform) or a path passed
            to ``buildFeatureExtractor``.
        size: forwarded as the ``size`` argument of ``FeatureTransform``.

    Returns:
        tuple: ``(featureExtractors, imgTransforms)``. When ``pathsModel`` is
        given these are two parallel lists; when it is ``None`` they are two
        single ``IDModule`` instances (not lists).
    """
    if pathsModel is None:
        return IDModule(), IDModule()

    featureExtractors = []
    imgTransforms = []
    for path in pathsModel:
        if path == "id":
            featureExtractor = IDModule()
            imgTransform = IDModule()
        else:
            featureExtractor, mean, std = buildFeatureExtractor(
                path, resetGrad=True)
            imgTransform = FeatureTransform(mean, std, size=size)
        featureExtractors.append(featureExtractor)
        imgTransforms.append(imgTransform)
    return featureExtractors, imgTransforms


def _barycenterStats(vectors):
    """Measure how the per-run latent vectors spread around their mean.

    ``vectors`` is flattened to ``(runs, -1)`` and every row is rescaled to
    unit root-mean-square. The barycenter (row mean) is rescaled the same
    way. Because both sides have unit RMS, ``(row * barycenter).mean()`` is
    their cosine similarity.

    Returns:
        tuple: ``(normalized, stats)``.
            - ``normalized`` is the flattened, RMS-normalised copy of
              ``vectors``; the input tensor itself is not modified.
            - ``stats`` is a dict with three keys:
                - ``meanDist``: mean over runs of the RMS distance to the
                  barycenter.
                - ``meanAngles``: mean over runs of the cosine similarity to
                  the barycenter.
                - ``stdAngles``: std over runs of that same cosine
                  similarity.

    Note:
        With a single run ``stdAngles`` is NaN (unbiased std of one value).
        ``json.dump`` then writes the non-standard token ``NaN``.
    """
    normalized = vectors.reshape(vectors.size(0), -1)
    # Out-of-place on purpose: the original multiplied a view of the
    # caller's tensor in place.
    normalized = normalized * torch.rsqrt(
        (normalized ** 2).mean(dim=1, keepdim=True))
    barycenter = normalized.mean(dim=0)
    barycenter = barycenter * torch.rsqrt((barycenter ** 2).mean())
    meanAngles = (normalized * barycenter).mean(dim=1)
    meanDist = torch.sqrt(
        ((barycenter - normalized) ** 2).mean(dim=1)).mean(dim=0)
    stats = {"meanDist": meanDist.item(),
             "stdAngles": meanAngles.std().item(),
             "meanAngles": meanAngles.mean().item()}
    return normalized, stats


def test(parser, visualisation=None):
    """Run inspirational generation on one input image.

    Overview:
        1. Extend ``parser`` with ``updateParser`` and parse the command
           line.
        2. Load the latest checkpoint of model ``name`` found under
           ``<dir>/<name>``.
        3. Replicate the input image ``nRuns`` times.
        4. Run ``gradientDescentOnInput`` on the model with the configured
           feature extractors. Presumably this searches latent vectors whose
           generations match the image's features.

    Results are written to a folder ``<image stem>_<suffix>``. Inside it,
    with ``base = <folder>/<folder basename>``:
        - ``<base>vector.pt``: latent vectors returned by the descent.
        - ``<base>.jpg``: image returned by the descent.
        - ``<base>_data.json``: loss values, barycenter statistics, and the
          parsed arguments.
        - ``<base>vectors.pt``: flattened, RMS-normalised latent vectors.
        - ``descent/``: intermediate outputs, only when ``save_descent``
          is set.

    Args:
        parser: argparse parser to extend and parse.
        visualisation: object providing ``saveTensor``. It is also forwarded
            to ``GANVisualizer`` and ``gradientDescentOnInput``. Required,
            despite the ``None`` default kept for compatibility.

    Raises:
        ValueError: if ``name``, ``module``, ``inputImage`` or
            ``visualisation`` is missing.
        AttributeError: if ``weights`` is given without a matching number of
            feature extractors.
        FileNotFoundError: if no checkpoint is found.
    """
    parser = updateParser(parser)
    kwargs = vars(parser.parse_args())

    # Parameters
    name = getVal(kwargs, "name", None)
    if name is None:
        raise ValueError("You need to input a name")
    module = getVal(kwargs, "module", None)
    if module is None:
        raise ValueError("You need to input a module")
    imgPath = getVal(kwargs, "inputImage", None)
    if imgPath is None:
        raise ValueError("You need to input an image path")
    if visualisation is None:
        # The result image is saved through visualisation.saveTensor, so a
        # missing module would otherwise only fail after the whole descent.
        raise ValueError("You need to provide a visualisation module")

    scale = getVal(kwargs, "scale", None)
    iteration = getVal(kwargs, "iter", None)
    nRuns = getVal(kwargs, "nRuns", 1)
    weights = getVal(kwargs, 'weights', None)
    pathsModel = getVal(kwargs, "featureExtractor", None)

    # Validate the command line before loading any model.
    if weights is not None:
        if pathsModel is None or len(pathsModel) != len(weights):
            raise AttributeError(
                "The number of weights must match the number of models")

    checkPointDir = os.path.join(kwargs["dir"], name)
    checkpointData = getLastCheckPoint(checkPointDir,
                                       name,
                                       scale=scale,
                                       iter=iteration)
    if checkpointData is None:
        raise FileNotFoundError(
            "No checkpoint found for model " + str(name) + " at directory "
            + str(checkPointDir) + ' cwd=' + str(os.getcwd()))
    modelConfig, pathModel, _ = checkpointData

    packageStr, modelTypeStr = getNameAndPackage(module)
    modelType = loadmodule(packageStr, modelTypeStr)
    visualizer = GANVisualizer(
        pathModel, modelConfig, modelType, visualisation)

    # Load the image and add a leading batch dimension.
    targetSize = visualizer.model.getSize()
    baseTransform = standardTransform(targetSize)
    inputImage = baseTransform(pil_loader(imgPath)).unsqueeze(0)

    featureExtractors, imgTransforms = _loadFeatureExtractors(
        pathsModel, kwargs["size"])

    basePath = os.path.splitext(imgPath)[0] + "_" + kwargs['suffix']
    os.makedirs(basePath, exist_ok=True)
    basePath = os.path.join(basePath, os.path.basename(basePath))
    print("All results will be saved in " + basePath)

    outPathDescent = None
    if kwargs['save_descent']:
        outPathDescent = os.path.join(
            os.path.dirname(basePath), "descent")
        os.makedirs(outPathDescent, exist_ok=True)

    # One copy of the input image per run along the batch dimension.
    fullInputs = torch.cat([inputImage] * nRuns, dim=0)
    img, outVectors, loss = gradientDescentOnInput(visualizer.model,
                                                   fullInputs,
                                                   featureExtractors,
                                                   imgTransforms,
                                                   visualizer=visualisation,
                                                   lambdaD=kwargs['lambdaD'],
                                                   nSteps=kwargs['nSteps'],
                                                   weights=weights,
                                                   randomSearch=kwargs[
                                                       'random_search'],
                                                   nevergrad=kwargs[
                                                       'nevergrad'],
                                                   lr=kwargs['learningRate'],
                                                   outPathSave=outPathDescent)

    # torch.save opens and closes the file itself when given a path (the
    # previous open(...) handles were never closed).
    torch.save(outVectors, basePath + "vector.pt")
    path = basePath + ".jpg"
    visualisation.saveTensor(img, (img.size(2), img.size(3)), path)

    outDictData = {}
    # Assumes `loss` iterates over scalar tensors — TODO confirm against
    # gradientDescentOnInput.
    outDictData[os.path.splitext(os.path.basename(path))[0]] = \
        [x.item() for x in loss]
    normalizedVectors, outDictData["Barycenter"] = _barycenterStats(outVectors)
    outDictData["kwargs"] = kwargs
    with open(basePath + "_data.json", 'w') as file:
        json.dump(outDictData, file, indent=2)
    torch.save(normalizedVectors, basePath + "vectors.pt")