def predict_fn()

in code/deploy_ei.py [0:0]


def predict_fn(input_data, model):
    """Run one gradient-free forward pass of ``model`` and return its first output.

    The model is moved to CUDA when available (otherwise CPU) and switched to
    eval mode. When the installed PyTorch build accepts the Elastic Inference
    form of ``torch.jit.optimized_execution``, the forward pass runs inside
    that context; otherwise it runs directly.

    Args:
        input_data: A two-item iterable ``(input_id, input_mask)``. Both items
            must support ``.to(device)`` (presumably token-id and
            attention-mask tensors -- confirm against the caller).
        model: A callable module invoked as
            ``model(input_id, attention_mask=input_mask)`` whose result is
            indexable; element ``[0]`` is returned.

    Returns:
        The first element of the model output, computed under
        ``torch.no_grad()``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    input_id, input_mask = input_data
    input_id = input_id.to(device)
    input_mask = input_mask.to(device)

    # Guard ONLY the construction of the Elastic Inference context. A PyTorch
    # build without EI support rejects the extra argument with a TypeError at
    # call time. Keeping the forward pass out of this ``try`` means a genuine
    # TypeError raised by the model is no longer swallowed and the model is
    # never silently executed a second time.
    try:
        ei_context = torch.jit.optimized_execution(True, {"target_device": "eia:0"})
    except TypeError:
        ei_context = None

    with torch.no_grad():
        if ei_context is None:
            y = model(input_id, attention_mask=input_mask)[0]
        else:
            with ei_context:
                print("==================== using elastic inference ====================")
                y = model(input_id, attention_mask=input_mask)[0]

    print("==================== inference result =======================")
    print(y)
    return y