in models/metrics/nn_score.py [0:0]
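
# Imports this excerpt relies on. The torch / torchvision imports are standard;
# the two project-local import paths below are assumed from the repository
# layout and may need adjusting.
import torch
import torch.utils.data
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as Transforms

from ..datasets.attrib_dataset import AttribDataset
from ..loss_criterions.ac_criterion import ACGANCriterion
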
def buildFeatureMaker(pathDB,
                      pathTrainAttrib,
                      pathValAttrib,
                      specificAttrib=None,
                      visualisation=None):
    r"""
    Fine-tune an ImageNet-pretrained ResNet-18 as an attribute classifier on
    the database at pathDB (attributes read from pathTrainAttrib and
    pathValAttrib), and return the trained network so that it can be used as
    a feature extractor.
    """
    # Parameters
    batchSize = 16
    nEpochs = 3
    learningRate = 1e-4
    beta1 = 0.9
    beta2 = 0.99

    device = torch.device("cuda:0")
    n_devices = torch.cuda.device_count()
    # Model
    resnet18 = models.resnet18(pretrained=True)
    resnet18.train()

    # Dataset
    size = 224
    transformList = [Transforms.Resize((size, size)),
                     Transforms.RandomHorizontalFlip(),
                     Transforms.ToTensor(),
                     Transforms.Normalize((0.485, 0.456, 0.406),
                                          (0.229, 0.224, 0.225))]
    transform = Transforms.Compose(transformList)
    dataset = AttribDataset(pathDB, transform=transform,
                            attribDictPath=pathTrainAttrib,
                            specificAttrib=specificAttrib,
                            mimicImageFolder=False)
    validationDataset = AttribDataset(pathDB, transform=transform,
                                      attribDictPath=pathValAttrib,
                                      specificAttrib=specificAttrib,
                                      mimicImageFolder=False)

    print("%d training images detected, %d validation images detected"
          % (len(dataset), len(validationDataset)))
    # Optimization
    optimizer = torch.optim.Adam(resnet18.parameters(),
                                 betas=[beta1, beta2],
                                 lr=learningRate)

    # The ACGAN criterion derives the classifier's output dimension from the
    # dataset's attribute categories; replace the fc head accordingly
    lossMode = ACGANCriterion(dataset.getKeyOrders())
    num_ftrs = resnet18.fc.in_features
    resnet18.fc = nn.Linear(num_ftrs, lossMode.getInputDim())

    # Wrap for multi-GPU training; the bare module is returned at the end
    resnet18 = nn.DataParallel(resnet18).to(device)
    # Visualization data
    lossTrain = []
    lossVal = []
    iterList = []
    tokenTrain = None
    tokenVal = None

    step = 0
    tmpLoss = 0
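
    # Training loop: one optimizer step per batch, nEpochs passes over the
    # training set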
    for epoch in range(nEpochs):

        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batchSize,
                                             shuffle=True,
                                             num_workers=n_devices)

        for iter, data in enumerate(loader):

            optimizer.zero_grad()

            inputs_real, labels = data
            inputs_real = inputs_real.to(device)
            labels = labels.to(device)

            predictedLabels = resnet18(inputs_real)
            loss = lossMode.getLoss(predictedLabels, labels)
            tmpLoss += loss.item()

            loss.backward()
            optimizer.step()
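
            # Every 100 steps: log the averaged training loss and run a quick
            # validation pass (capped at 100 batches)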
            if step % 100 == 0 and visualisation is not None:

                divisor = 100
                if step == 0:
                    divisor = 1

                lossTrain.append(tmpLoss / divisor)
                iterList.append(step)

                tokenTrain = visualisation.publishLinePlot(
                    [('lossTrain', lossTrain)], iterList,
                    name="Loss train",
                    window_token=tokenTrain,
                    env="main")

                validationLoader = torch.utils.data.DataLoader(
                    validationDataset,
                    batch_size=batchSize,
                    shuffle=True,
                    num_workers=n_devices)

                resnet18.eval()
                lossEval = 0
                i = 0

                with torch.no_grad():
                    for valData in validationLoader:

                        # Use the validation batch, not the last training batch
                        inputs_real, labels = valData
                        inputs_real = inputs_real.to(device)
                        labels = labels.to(device)

                        # Recompute predictions on the validation inputs
                        predictedLabels = resnet18(inputs_real)
                        lossEval += lossMode.getLoss(predictedLabels,
                                                     labels).item()

                        i += 1
                        if i == 100:
                            break

                lossEval /= i
                lossVal.append(lossEval)

                tokenVal = visualisation.publishLinePlot(
                    [('lossValidation', lossVal)], iterList,
                    name="Loss validation",
                    window_token=tokenVal,
                    env="main")

                resnet18.train()

                print("[%5d ; %5d ] Loss train : %f ; Loss validation %f"
                      % (epoch, iter, tmpLoss / divisor, lossEval))
                tmpLoss = 0

            step += 1

    return resnet18.module
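
# Illustrative usage sketch (paths below are placeholders, not real files):
#
#   featureExtractor = buildFeatureMaker("datasets/images",
#                                        "train_attributes.json",
#                                        "val_attributes.json",
#                                        visualisation=None)
#   featureExtractor.eval()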