packages/@aws-cdk/aws-sagemaker-alpha/lib/model-data.ts (50 lines of code) (raw):

import * as s3 from 'aws-cdk-lib/aws-s3'; import * as assets from 'aws-cdk-lib/aws-s3-assets'; import { Construct } from 'constructs'; import { IModel } from './model'; import { hashcode } from './private/util'; // The only supported extension for local asset model data // https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-properties-sagemaker-model-containerdefinition.html#cfn-sagemaker-model-containerdefinition-modeldataurl const ARTIFACT_EXTENSION = '.tar.gz'; /** * The configuration needed to reference model artifacts. */ export interface ModelDataConfig { /** * The S3 path where the model artifacts, which result from model training, are stored. This path * must point to a single gzip compressed tar archive (.tar.gz suffix). */ readonly uri: string; } /** * Model data represents the source of model artifacts, which will ultimately be loaded from an S3 * location. */ export abstract class ModelData { /** * Constructs model data which is already available within S3. * @param bucket The S3 bucket within which the model artifacts are stored * @param objectKey The S3 object key at which the model artifacts are stored */ public static fromBucket(bucket: s3.IBucket, objectKey: string): ModelData { return new S3ModelData(bucket, objectKey); } /** * Constructs model data that will be uploaded to S3 as part of the CDK app deployment. * @param path The local path to a model artifact file as a gzipped tar file * @param options The options to further configure the selected asset */ public static fromAsset(path: string, options: assets.AssetOptions = {}): ModelData { return new AssetModelData(path, options); } /** * This method is invoked by the SageMaker Model construct when it needs to resolve the model * data to a URI. * @param scope The scope within which the model data is resolved * @param model The Model construct performing the URI resolution */ public abstract bind(scope: Construct, model: IModel): ModelDataConfig; } class S3ModelData extends ModelData { constructor(private readonly bucket: s3.IBucket, private readonly objectKey: string) { super(); } public bind(_scope: Construct, model: IModel): ModelDataConfig { this.bucket.grantRead(model); return { uri: this.bucket.urlForObject(this.objectKey), }; } } class AssetModelData extends ModelData { private asset?: assets.Asset; constructor(private readonly path: string, private readonly options: assets.AssetOptions) { super(); if (!path.toLowerCase().endsWith(ARTIFACT_EXTENSION)) { throw new Error(`Asset must be a gzipped tar file with extension ${ARTIFACT_EXTENSION} (${this.path})`); } } public bind(scope: Construct, model: IModel): ModelDataConfig { // Retain the first instantiation of this asset if (!this.asset) { this.asset = new assets.Asset(scope, `ModelData${hashcode(this.path)}`, { path: this.path, ...this.options, }); } this.asset.grantRead(model); return { uri: this.asset.httpUrl, }; } }