def model_fn()

in sagemaker_notebook_instance/containers/summarization/entry_point.py [0:0]


def model_fn(model_dir):
    """Download model assets from S3 and build a T5 summarization pipeline.

    The S3 location of the assets comes from environment variables:

    * ``MODEL_ASSETS_S3_BUCKET`` (required): bucket holding the assets.
    * ``MODEL_ASSETS_S3_PREFIX`` (optional): key prefix inside the bucket.
      When unset it defaults to ``""``, matching the default of
      ``sagemaker.Session.download_data``'s ``key_prefix`` argument.

    The objects are downloaded into ``model_dir``. That directory is then
    passed as ``cache_dir`` when loading ``t5-base``. The S3 assets are
    presumably a pre-populated transformers cache, so loading does not need
    the Hugging Face Hub (TODO confirm).

    Args:
        model_dir: Local directory to download the assets into. It is also
            used as the transformers cache directory.

    Returns:
        dict: ``{"summarizer": pipeline}``, where ``pipeline`` is a
        transformers ``"summarization"`` pipeline built from the loaded
        model and tokenizer.

    Raises:
        ValueError: If ``MODEL_ASSETS_S3_BUCKET`` is unset or empty.
    """
    # Validate configuration before creating any AWS session. A missing
    # bucket would otherwise surface as an opaque S3 parameter-validation
    # error from inside download_data.
    bucket = os.getenv("MODEL_ASSETS_S3_BUCKET")
    if not bucket:
        raise ValueError(
            "Environment variable MODEL_ASSETS_S3_BUCKET must be set to the "
            "S3 bucket containing the model assets."
        )
    # download_data's own default key_prefix is ""; passing None through
    # would be rejected by the S3 client.
    prefix = os.getenv("MODEL_ASSETS_S3_PREFIX", "")

    session = sagemaker.Session()
    session.download_data(path=model_dir, bucket=bucket, key_prefix=prefix)

    model_name = "t5-base"
    # NOTE(review): AutoModelWithLMHead is deprecated in transformers. For an
    # encoder-decoder model such as T5, AutoModelForSeq2SeqLM is the
    # recommended replacement (it needs a matching import at file level).
    model = AutoModelWithLMHead.from_pretrained(model_name, cache_dir=model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=model_dir)
    summarizer = pipeline(
        task="summarization",
        model=model,
        tokenizer=tokenizer,
    )
    return {"summarizer": summarizer}