def train_sim_model()

in src/similarity/siamese2.py [0:0]


def train_sim_model(model, train_dl, optimizer, num_epochs=args.epochs):
    """Train a pairwise similarity model and return it with the best weights loaded.

    For every epoch the model is run over ``train_dl``; each batch is a mapping
    with the keys ``'img1'``, ``'img2'`` and ``'labels'``.  Leading dimensions
    are flattened so the images become 4-D and the labels 1-D before the model
    is called.  The loss is ``contrastive_loss(distance, labels)`` and a pair
    counts as correct when ``|distance - label| < args.similarity_margin``.

    Side effects:
        * ``BEST_MODEL_METRIC['train-loss']`` and ``['train-acc']`` are
          overwritten after every epoch (both normalised by
          ``train_dl.dataset.amplified_size()``).
        * Whenever ``BEST_MODEL_METRIC[args.best_model_metric]`` improves
          (gets smaller), the state dict is written to ``CHECKPOINT_PATH`` and
          a one-line summary to ``CHECKPOINT_STATE_PATH``.
        * If training fails or is interrupted, the failure is logged, the
          checkpoint is reloaded, the model is moved to the CPU and its state
          dict is written to ``MODEL_PATH``.

    Args:
        model: Module called as ``model(img1, img2)`` to produce a distance.
        train_dl: Data loader whose ``dataset`` exposes ``amplified_size()``.
        optimizer: Optimizer stepping the parameters of ``model``.
        num_epochs: Number of passes over ``train_dl``.

    Returns:
        ``model`` with the weights from ``CHECKPOINT_PATH`` loaded.

    Raises:
        OSError: If ``CHECKPOINT_PATH`` cannot be opened during recovery
            (e.g. no checkpoint was ever written).
    """
    try:
        since = time.time()
        # Start at +inf rather than a magic constant so the first finite
        # metric always produces a checkpoint for the reload below.
        best_loss = float('inf')
        model = model.to(DEVICE)

        ntuples = train_dl.dataset.amplified_size()

        for epoch in range(num_epochs):

            logger.info('\n Epoch {}/{}'.format(epoch, num_epochs - 1))
            logger.info('-' * 10)

            # Only a training phase is run here; there is no validation pass.
            model.train()

            running_loss = 0.0
            running_corrects = 0

            for data in train_dl:

                # Collapse any leading dimensions into a single batch axis.
                img1 = data['img1'].to(DEVICE)
                img1 = img1.view(-1, *img1.shape[-3:])

                img2 = data['img2'].to(DEVICE)
                img2 = img2.view(-1, *img2.shape[-3:])

                labels = data['labels'].to(DEVICE).float().view(-1)

                optimizer.zero_grad()

                # Call the module itself (not .forward) so registered hooks run.
                distance = model(img1, img2)

                loss = contrastive_loss(distance, labels)
                loss.backward()
                optimizer.step()

                # statistics
                predictions = (torch.abs(distance - labels) < args.similarity_margin).int()
                running_loss += loss.item()
                running_corrects += torch.sum(predictions)

            print()

            BEST_MODEL_METRIC['train-loss'] = running_loss / ntuples
            # running_corrects is still a plain int if the loader yielded no
            # batches; as_tensor makes .double() valid in that case too.
            BEST_MODEL_METRIC['train-acc'] = torch.as_tensor(running_corrects).double() / ntuples

            logger.info('Training set: Average loss: {:.8f}, Average acc: {:.8f} \n'
                        .format(BEST_MODEL_METRIC['train-loss'], BEST_MODEL_METRIC['train-acc']))

            # checkpoint the best model
            if BEST_MODEL_METRIC[args.best_model_metric] < best_loss:
                best_loss = BEST_MODEL_METRIC[args.best_model_metric]

                logger.info('Saving the best model: {}'.format(best_loss))
                with open(CHECKPOINT_PATH, 'wb') as f:
                    torch.save(model.state_dict(), f)
                # NOTE(review): '{:5.2f}' prints small learning rates as 0.00;
                # left unchanged in case something parses this file.
                with open(CHECKPOINT_STATE_PATH, 'w') as f:
                    f.write('epoch {:3d} | lr: {:5.2f} | loss {:.8f}'
                            .format(epoch, args.learning_rate, best_loss))

        time_elapsed = time.time() - since
        logger.info('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        logger.info('Best Loss: {:8f}'.format(best_loss))

        # Load the best saved model.
        # NOTE(review): unlike the recovery path below, the success path does
        # not write MODEL_PATH or move the model to the CPU - confirm that the
        # caller handles this.
        with open(CHECKPOINT_PATH, 'rb') as f:
            model.load_state_dict(torch.load(f))

    except (Exception, KeyboardInterrupt):
        # Deliberate best-effort recovery (also covers Ctrl+C), but record the
        # cause instead of swallowing it; SystemExit is allowed to propagate.
        logger.exception('Training stopped early; restoring the best checkpoint')

        # Load the best saved model.
        with open(CHECKPOINT_PATH, 'rb') as f:
            model.load_state_dict(torch.load(f))

        # Move the best model to cpu and resave it
        with open(MODEL_PATH, 'wb') as f:
            torch.save(model.cpu().state_dict(), f)

    return model