in sagemaker_notebook_instance/containers/relationship_extraction/package/inference.py [0:0]
def predict_fn(request, model_assets):
    """Predict the relationship label for two entity spans in a sequence.

    Args:
        request: Mapping with keys ``'sequence'``, ``'entity_one_start'``,
            ``'entity_one_end'``, ``'entity_two_start'`` and
            ``'entity_two_end'``. Each value is forwarded unchanged, as a
            keyword argument, to the tokenizer's ``encode`` method.
        model_assets: Mapping with keys ``'tokenizer'`` (provides
            ``encode``), ``'model'`` (a callable taking ``token_ids`` and
            ``attention_mask`` and returning logits) and ``'label_encoder'``
            (provides ``id_to_str``).

    Returns:
        dict: ``{'id': <int>, 'str': <label>}``. ``'id'`` is the arg-max class
        index for the single input; ``'str'`` is whatever
        ``label_encoder.id_to_str`` returns for that index.

    Raises:
        KeyError: If ``request`` or ``model_assets`` lacks a required key.
    """
    encoding = model_assets['tokenizer'].encode(
        sequence=request['sequence'],
        entity_one_start=request['entity_one_start'],
        entity_one_end=request['entity_one_end'],
        entity_two_start=request['entity_two_start'],
        entity_two_end=request['entity_two_end'],
    )
    # Add a leading batch dimension of size 1: one request is served per call.
    token_ids = torch.tensor(encoding['ids']).unsqueeze(0)
    attention_mask = torch.tensor(encoding['attention_mask']).unsqueeze(0)

    # Inference only: disable autograd so the forward pass does not record a
    # computation graph. This saves memory and time; the outputs are the same.
    # NOTE(review): no_grad does not switch off dropout/batch-norm updates.
    # This assumes the model was put in eval mode when it was loaded
    # (presumably in model_fn) — confirm.
    with torch.no_grad():
        logits = model_assets['model'](
            token_ids=token_ids,
            attention_mask=attention_mask,
        )

    # argmax over dim 1 assumes logits are (batch, num_classes) — TODO confirm.
    # Take element 0 because the batch holds exactly one example.
    pred_id = torch.argmax(logits, dim=1)[0].item()
    return {
        'id': pred_id,
        'str': model_assets['label_encoder'].id_to_str(pred_id),
    }