def predict_fn()

in src/similarity/batch.py [0:0]


def predict_fn(input_data: torch.Tensor, model: torch.nn.Module) -> list:
    """Run the model on the first row of ``input_data`` against the remaining rows.

    ``input_data`` is split along dimension 0: the first entry (kept with a
    leading dimension of size 1) is passed as the first model argument and
    all remaining entries are passed as the second.

    Args:
        input_data: Tensor whose dimension 0 indexes the stacked inputs.
            Presumably a batch of images with the reference image first --
            TODO confirm against the caller that builds this tensor.
        model: Callable module accepting two tensors. Both the model and the
            input are moved to CUDA when available, otherwise to CPU.

    Returns:
        The model output converted with ``.tolist()`` (the code calls these
        "distances"; presumably one value per comparison row -- confirm
        against the model definition).
    """
    logger.info('Making prediction.')

    # Pick the accelerator when present; model and data must share a device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    input_data = input_data.to(device)

    logger.info(input_data.shape)

    # narrow() keeps the split dimension, so img1 has size 1 along dim 0 and
    # img2 holds every remaining row.
    img1 = input_data.narrow(0, 0, 1)
    img2 = input_data.narrow(0, 1, input_data.shape[0] - 1)

    logger.info(img1.shape)
    logger.info(img2.shape)

    # NOTE(review): assumes the model was already switched to eval mode by
    # whatever loaded it -- confirm; this function does not change the mode.
    # Gradients are never needed for inference, so skip building the autograd
    # graph. Calling the module (not .forward) also runs any registered hooks.
    with torch.no_grad():
        distances = model(img1, img2).tolist()

    logger.info(distances)

    return distances