def load_sagemaker_model_artifact()

in utils.py [0:0]


def load_sagemaker_model_artifact(s3_bucket: str, key: str, s3_client=None) -> dict:
    """Load a PyTorch model artifact (model.tar.gz) produced by a SageMaker
    Training job.

    The archive is downloaded into memory and one regular file is chosen from
    it. The choice is the last member whose name ends in ``.pth`` or ``.pt``.
    If there is no such member, the last regular file is used. Directory
    entries and other non-file members are ignored. The chosen file is
    deserialized with ``torch.load`` onto the CPU.

    Security note: ``torch.load`` unpickles its input, so only load artifacts
    from trusted buckets.

    Args:
        s3_bucket: S3 bucket name. Both the bare name (``bucket_name``) and
            the URI form (``s3://bucket_name``) are accepted.
        key: object key, i.e. the path to model.tar.gz within the bucket.
        s3_client: optional pre-configured S3 client exposing ``get_object``.
            Defaults to ``boto3.client("s3")``.
    Returns:
        state_dict: the object returned by ``torch.load`` for the selected
            archive member (typically a dict-like PyTorch checkpoint).
    Raises:
        ValueError: if the archive contains no regular files.
    """
    # boto3 expects a bare bucket name; tolerate the s3:// URI form as well.
    if s3_bucket.startswith("s3://"):
        s3_bucket = s3_bucket[len("s3://"):]
    if s3_client is None:
        s3_client = boto3.client("s3")
    # load the s3 object and read it fully into memory
    obj = s3_client.get_object(Bucket=s3_bucket, Key=key)
    model_artifact = BytesIO(obj["Body"].read())
    # parse out the checkpoint from the tar.gz file; mode "r" auto-detects gzip
    with tarfile.open(fileobj=model_artifact) as tar:
        # extractfile() returns None for directories etc., so keep real files only
        files = [member for member in tar.getmembers() if member.isfile()]
        if not files:
            raise ValueError(
                f"no regular files found in model artifact s3://{s3_bucket}/{key}"
            )
        # prefer an actual checkpoint file over auxiliary files (e.g. code/)
        checkpoints = [m for m in files if m.name.endswith((".pth", ".pt"))]
        member = (checkpoints or files)[-1]
        with tar.extractfile(member) as fh:
            pth = fh.read()

    state_dict = torch.load(BytesIO(pth), map_location=torch.device("cpu"))
    return state_dict