in sagemaker/24_train_bloom_peft_lora/scripts/inference.py [0:0]
def predict_fn(data, model_and_tokenizer):
    """Generate text for a single request payload.

    Args:
        data: Request payload mapping. The value under ``"inputs"`` is
            tokenized; if that key is missing, the payload itself (minus
            ``"parameters"``) is used as the input. The optional
            ``"parameters"`` value is forwarded as keyword arguments to
            ``model.generate``. The caller's mapping is not modified.
        model_and_tokenizer: A ``(model, tokenizer)`` pair.

    Returns:
        A one-element list ``[{"generated_text": <decoded text>}]``. Only the
        first sequence returned by ``model.generate`` is decoded.
    """
    model, tokenizer = model_and_tokenizer

    # Work on a shallow copy so the pops below do not strip keys from the
    # caller's request object (it may still be logged or re-used upstream).
    payload = dict(data)
    inputs = payload.pop("inputs", payload)
    parameters = payload.pop("parameters", None)

    # Tokenize and move the ids onto the same device as the model.
    encoded = tokenizer(inputs, return_tensors="pt")
    input_ids = encoded.input_ids.to(model.device)

    # A missing "parameters" entry means "call generate with its defaults".
    generation_kwargs = {} if parameters is None else parameters
    outputs = model.generate(input_ids, **generation_kwargs)

    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return [{"generated_text": prediction}]