in src/similarity/siamese.py [0:0]
def train_sim_model(model, train_dl, test_dl, optimizer, num_epochs=args.epochs):
    """Train the similarity model and return the best checkpointed weights on CPU.

    Every epoch runs one optimisation pass over ``train_dl``, evaluates the
    model with ``test_model`` on ``test_dl``, stores the results in the
    module-level ``BEST_MODEL_METRIC`` dict and writes a checkpoint to
    ``CHECKPOINT_PATH`` whenever the metric selected by
    ``args.best_model_metric`` is lower than the best value seen so far.

    Training is best-effort: if it is interrupted (Ctrl-C) or fails, the
    reason is logged and the function still falls back to whatever
    checkpoint is on disk. That checkpoint is loaded into ``model``, the
    model is moved to CPU, and its weights are written to ``MODEL_PATH``.

    Args:
        model: Network invoked as ``model(img1, img2)``; its output is fed to
            ``contrastive_loss`` together with the labels.
        train_dl: Iterable of batches, each a mapping with ``'img1'``,
            ``'img2'`` and ``'label'`` tensors. Must expose ``.dataset``
            supporting ``len()``.
        test_dl: Evaluation loader, passed unchanged to ``test_model``.
        optimizer: Optimizer stepped once per training batch.
        num_epochs: Number of epochs to run.

    Returns:
        ``model`` on CPU, loaded with the weights from ``CHECKPOINT_PATH``.

    Raises:
        FileNotFoundError: If no checkpoint exists at ``CHECKPOINT_PATH``
            when the best weights are reloaded.
    """
    since = time.time()
    # NOTE(review): a checkpoint is only written once the selected metric
    # drops below this starting value -- confirm 1000.0 is high enough.
    best_loss = 1000.0

    try:
        model = model.to(DEVICE)
        for epoch in range(num_epochs):
            logger.info('Epoch {}/{}'.format(epoch, num_epochs - 1))
            logger.info('-' * 10)

            # Each epoch has a training and validation phase
            model.train()  # Set model to training mode
            running_loss = 0.0
            # Start from a tensor (not a Python int) so that `.double()`
            # below also works when the loader yields no batches.
            running_corrects = torch.zeros((), dtype=torch.long, device=DEVICE)

            # Iterate over data.
            for data in train_dl:
                # Collapse any leading dimensions into a single batch
                # dimension, keeping the last three dimensions as they are.
                img1 = data['img1'].to(DEVICE)
                img1 = img1.view(-1, img1.shape[-3], img1.shape[-2], img1.shape[-1])
                img2 = data['img2'].to(DEVICE)
                img2 = img2.view(-1, img2.shape[-3], img2.shape[-2], img2.shape[-1])
                labels = data['label'].to(DEVICE).float()
                labels = labels.view(-1)

                # zero the parameter gradients
                optimizer.zero_grad()
                # Call the module itself rather than `.forward` so that any
                # registered hooks run.
                distance = model(img1, img2)
                loss = contrastive_loss(distance, labels)
                loss.backward()
                optimizer.step()

                # statistics (bookkeeping only, so keep it out of autograd)
                with torch.no_grad():
                    predictions = (
                        torch.abs(distance - labels) < args.similarity_margin
                    ).int()
                    running_corrects += torch.sum(predictions)
                running_loss += loss.item()

            # Emit a blank line on stdout after each training pass.
            print()

            # NOTE(review): both statistics are divided by the dataset
            # length; if `contrastive_loss` already averages over the batch,
            # or a dataset item expands to several pairs via the `view`
            # above, these are not true per-sample averages -- confirm.
            BEST_MODEL_METRIC['train-loss'] = running_loss / len(train_dl.dataset)
            BEST_MODEL_METRIC['train-acc'] = running_corrects.double() / len(train_dl.dataset)
            logger.info('Training set: Average loss: {:.8f}, Average acc: {:.8f} \n'
                        .format(BEST_MODEL_METRIC['train-loss'], BEST_MODEL_METRIC['train-acc']))

            BEST_MODEL_METRIC['test-loss'], BEST_MODEL_METRIC['test-acc'] = test_model(model, test_dl)

            # checkpoint the best model
            # NOTE(review): "lower is better" is assumed for whichever metric
            # `args.best_model_metric` selects -- confirm it is never an
            # accuracy key.
            if BEST_MODEL_METRIC[args.best_model_metric] < best_loss:
                best_loss = BEST_MODEL_METRIC[args.best_model_metric]
                logger.info('Saving the best model: {}'.format(best_loss))
                with open(CHECKPOINT_PATH, 'wb') as f:
                    torch.save(model.state_dict(), f)
                with open(CHECKPOINT_STATE_PATH, 'w') as f:
                    f.write('epoch {:3d} | lr: {:5.2f} | loss {:.8f}'
                            .format(epoch, args.learning_rate, best_loss))

        time_elapsed = time.time() - since
        logger.info('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        logger.info('Best Loss: {:8f}'.format(best_loss))
    except KeyboardInterrupt:
        # Manual early stop: keep going with the best checkpoint so far.
        logger.warning('Training interrupted; falling back to the best saved checkpoint')
    except Exception:
        # Stay best-effort, but record the failure instead of hiding it.
        logger.exception('Training failed; falling back to the best saved checkpoint')

    # Load the best saved model.
    with open(CHECKPOINT_PATH, 'rb') as f:
        model.load_state_dict(torch.load(f))

    # Move the best model to cpu and resave it
    with open(MODEL_PATH, 'wb') as f:
        torch.save(model.cpu().state_dict(), f)
    return model