def predict_fn()

in sagemaker_notebook_instance/containers/relationship_extraction/package/inference.py [0:0]


def predict_fn(request, model_assets):
    """Predict the relationship label for one pair of entities in a sequence.

    Args:
        request: Mapping with the keys ``'sequence'``, ``'entity_one_start'``,
            ``'entity_one_end'``, ``'entity_two_start'`` and
            ``'entity_two_end'``. All five are forwarded unchanged as keyword
            arguments to the tokenizer's ``encode`` method.
        model_assets: Mapping with the keys:
            ``'tokenizer'``: exposes ``encode(...)``, which returns a mapping
            containing ``'ids'`` and ``'attention_mask'``.
            ``'model'``: callable taking ``token_ids`` and ``attention_mask``
            keyword arguments and returning logits.
            ``'label_encoder'``: exposes ``id_to_str(int)``.

    Returns:
        dict: ``{'id': <predicted class index>, 'str': <label for that index>}``.
    """
    encoding = model_assets['tokenizer'].encode(
        sequence=request['sequence'],
        entity_one_start=request['entity_one_start'],
        entity_one_end=request['entity_one_end'],
        entity_two_start=request['entity_two_start'],
        entity_two_end=request['entity_two_end']
    )

    # unsqueeze(0) adds a leading batch dimension of size 1.
    token_ids = torch.tensor(encoding['ids']).unsqueeze(0)
    attention_mask = torch.tensor(encoding['attention_mask']).unsqueeze(0)

    # Inference only: disable autograd so no computation graph is recorded
    # for this forward pass, which saves memory and time per request.
    # NOTE(review): dropout/batch-norm mode is not set here; presumably the
    # model was put in eval() mode when it was loaded. Confirm in model_fn.
    with torch.no_grad():
        logits = model_assets['model'](
            token_ids=token_ids,
            attention_mask=attention_mask
        )

    # Assumes logits are (batch, num_labels). TODO confirm.
    # Take the highest-scoring label for the single example in the batch.
    pred_id = torch.argmax(logits, dim=1)[0].item()
    return {
        'id': pred_id,
        'str': model_assets['label_encoder'].id_to_str(pred_id)
    }