def predict_fn()

in src/similarity/inference.py [0:0]


def predict_fn(input_data, model):
    """Run the pair model on two images packed into a single tensor.

    ``input_data`` is split along dim 0 into chunks of ``shape[0] // 2``
    entries; the first two chunks are each given a new leading dimension and
    passed to ``model`` as its two positional inputs.
    NOTE(review): presumably the two images are concatenated along dim 0
    (e.g. channels) and the new leading dimension is the batch axis --
    confirm against the input handler that builds ``input_data``.

    Args:
        input_data: torch tensor holding both images along dim 0.
        model: torch module called as ``model(img1, img2)``; element ``[0]``
            of its output must contain exactly one value.

    Returns:
        ``{"similarity": value}`` where ``value`` is the Python scalar
        extracted with ``.item()`` from the model output.
        NOTE(review): the code calls this value a distance, so a smaller
        number may mean "more similar" -- confirm with the model definition.

    Raises:
        ValueError: if ``input_data`` has fewer than two entries along dim 0
            and therefore cannot be split into two images.
    """
    logger.info('Making prediction.')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    input_data = input_data.to(device)

    logger.info(input_data.shape)
    if input_data.shape[0] < 2:
        # A split size of 0 would otherwise fail later with an opaque error.
        raise ValueError(
            'input_data must have at least 2 entries along dim 0, got shape '
            f'{tuple(input_data.shape)}'
        )

    # Integer division keeps the split size an exact int (no float round-trip).
    half = input_data.shape[0] // 2
    images = torch.split(input_data, half)

    # Out-of-place unsqueeze: do not modify the views returned by torch.split.
    img1 = images[0].unsqueeze(0)
    img2 = images[1].unsqueeze(0)
    logger.info(img1.shape)
    logger.info(img2.shape)
    # Log the model itself; print() returns None, so wrapping it in
    # logger.info() would only record "None".
    logger.info(model)

    # Inference only: skip autograd bookkeeping. Calling the module (rather
    # than .forward) lets registered forward hooks run.
    with torch.no_grad():
        distance = model(img1, img2)[0].item()
    logger.info(distance)
    return {"similarity": distance}