in utils.py [0:0]
def load_sagemaker_model_artifact(
    s3_bucket: str, key: str, member_name: "str | None" = None
) -> dict:
    """Load a PyTorch checkpoint from a SageMaker ``model.tar.gz`` artifact.

    The archive is downloaded from S3 entirely into memory. The checkpoint file
    is then located inside it and deserialized with ``torch.load`` onto the CPU.

    Args:
        s3_bucket: Name of the S3 bucket holding the artifact, e.g.
            ``"my-bucket"``. An ``s3://`` prefix (``"s3://my-bucket"``) is
            tolerated and stripped, because boto3 expects a bare bucket name.
        key: Object key of ``model.tar.gz`` within the bucket.
        member_name: Optional path of the checkpoint inside the archive, e.g.
            ``"model.pth"``. A leading ``"./"`` is ignored when matching. If
            omitted, the last regular file ending in ``.pt``/``.pth`` is used.
            If no such file exists, the last regular file in the archive is used.

    Returns:
        The object ``torch.load`` deserializes from the selected checkpoint.
        This is typically the model's state dict.

    Raises:
        ValueError: If the archive contains no regular files, or if
            ``member_name`` is given but no regular file with that name exists.
    """
    # boto3's ``Bucket`` parameter rejects URL-style names such as "s3://b".
    if s3_bucket.startswith("s3://"):
        s3_bucket = s3_bucket[len("s3://"):].rstrip("/")

    s3 = boto3.client("s3")
    obj = s3.get_object(Bucket=s3_bucket, Key=key)
    # Buffer the whole object: tarfile needs a seekable file object.
    model_artifact = BytesIO(obj["Body"].read())

    # Context managers ensure the archive and member handles are closed.
    with tarfile.open(fileobj=model_artifact) as tar:
        member = _select_checkpoint_member(tar, member_name)
        with tar.extractfile(member) as handle:
            payload = handle.read()

    # NOTE(security): torch.load unpickles arbitrary Python objects. Only call
    # this on artifacts from a trusted source (e.g. your own training jobs).
    return torch.load(BytesIO(payload), map_location=torch.device("cpu"))


def _select_checkpoint_member(
    tar: tarfile.TarFile, member_name: "str | None" = None
) -> tarfile.TarInfo:
    """Pick the archive member that holds the PyTorch checkpoint."""

    def _normalize(name: str) -> str:
        # Archives built with ``tar -czf model.tar.gz .`` prefix names with "./".
        return name[2:] if name.startswith("./") else name

    # Only regular files carry data. extractfile() returns None for directories,
    # which would crash a subsequent .read().
    files = [m for m in tar.getmembers() if m.isfile()]
    if not files:
        raise ValueError("model artifact contains no regular files")

    if member_name is not None:
        wanted = _normalize(member_name)
        matches = [m for m in files if _normalize(m.name) == wanted]
        if not matches:
            raise ValueError(f"member {member_name!r} not found in model artifact")
        return matches[-1]

    # SageMaker archives may bundle extra files (e.g. code/inference.py), so
    # prefer files that look like torch checkpoints. Last one wins, as before.
    checkpoints = [m for m in files if m.name.lower().endswith((".pt", ".pth"))]
    return (checkpoints or files)[-1]