def model_fn()

in sagemaker_notebook_instance/containers/relationship_extraction/package/inference.py [0:0]


def model_fn(model_dir):
    """Load the relationship-extraction model and its preprocessing assets.

    Model-loading entry point; by name and file location this appears to be
    the SageMaker inference hook that is called with the directory holding
    the extracted model artifacts (confirm against the serving container).

    Args:
        model_dir: Directory (str or path-like) containing
            ``tokenizer.json``, ``label_encoder.json`` and ``model.ckpt``.

    Returns:
        dict with keys ``'tokenizer'``, ``'label_encoder'`` and ``'model'``.
        The model is switched to evaluation mode via ``model.eval()`` before
        being returned.
    """
    artifacts = Path(model_dir)

    tokenizer = RelationshipTokenizer.from_file(
        file_path=artifacts / 'tokenizer.json',
        contains_entity_tokens=True
    )
    label_encoder = LabelEncoder.from_file(
        file_path=artifacts / 'label_encoder.json'
    )
    # The tokenizer and label encoder are passed through as constructor
    # arguments so the restored module uses the same instances returned below.
    model = RelationshipEncoderLightningModule.load_from_checkpoint(
        str(artifacts / 'model.ckpt'),
        tokenizer=tokenizer,
        label_encoder=label_encoder
    )
    model.eval()

    return {
        'tokenizer': tokenizer,
        'label_encoder': label_encoder,
        'model': model
    }