def predict_fn()

in sagemaker/24_train_bloom_peft_lora/scripts/inference.py [0:0]


def predict_fn(data, model_and_tokenizer):
    """Generate text for one inference request.

    Args:
        data: Request payload mapping. The prompt is read from the
            ``"inputs"`` key; if that key is absent, the payload itself
            (without ``"parameters"``) is handed to the tokenizer. An
            optional ``"parameters"`` mapping is forwarded as keyword
            arguments to ``model.generate``. The mapping is not modified.
        model_and_tokenizer: A ``(model, tokenizer)`` pair.

    Returns:
        A one-element list ``[{"generated_text": text}]`` where ``text`` is
        the decoded first sequence returned by ``model.generate``.
    """
    model, tokenizer = model_and_tokenizer

    # Work on a shallow copy so the pops below never strip keys from the
    # caller's payload.
    payload = dict(data)
    inputs = payload.pop("inputs", payload)
    parameters = payload.pop("parameters", None)
    if parameters is None:
        # No generation options supplied: call generate() with its defaults.
        parameters = {}

    # Tokenize the prompt and move the ids to the device the model lives on.
    input_ids = tokenizer(inputs, return_tensors="pt").input_ids.to(model.device)

    outputs = model.generate(input_ids, **parameters)

    # Only the first generated sequence is decoded and returned.
    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return [{"generated_text": prediction}]