packages/blueprints/gen-ai-chatbot/static-assets/chatbot-genai-cdk/lib/constructs/websocket.ts (138 lines of code) (raw):

import { Construct } from "constructs";
import * as apigwv2 from "aws-cdk-lib/aws-apigatewayv2";
import { WebSocketLambdaIntegration } from "aws-cdk-lib/aws-apigatewayv2-integrations";
import {
  DockerImageCode,
  DockerImageFunction,
  IFunction,
  Runtime,
} from "aws-cdk-lib/aws-lambda";
import { NodejsFunction } from "aws-cdk-lib/aws-lambda-nodejs";
import * as path from "path";
import * as iam from "aws-cdk-lib/aws-iam";
import { CfnOutput, Duration, RemovalPolicy, Stack } from "aws-cdk-lib";
import { Platform } from "aws-cdk-lib/aws-ecr-assets";
import { Auth } from "./auth";
import { ITable } from "aws-cdk-lib/aws-dynamodb";
import { CfnRouteResponse } from "aws-cdk-lib/aws-apigatewayv2";
import { ISecret } from "aws-cdk-lib/aws-secretsmanager";
import * as ec2 from "aws-cdk-lib/aws-ec2";
import * as s3 from "aws-cdk-lib/aws-s3";

/**
 * Inputs required to build the {@link WebSocket} construct.
 */
export interface WebSocketProps {
  /** VPC the handler Lambda is placed in (private subnets with egress). */
  readonly vpc: ec2.IVpc;
  /** Table whose name is handed to the handler as `TABLE_NAME`. */
  readonly database: ITable;
  /** Secret the handler may read; its ARN is exposed as `DB_SECRETS_ARN`. */
  readonly dbSecrets: ISecret;
  /** Auth construct supplying the user pool id and user pool client id. */
  readonly auth: Auth;
  /** Region string passed to the handler as `BEDROCK_REGION`. */
  readonly bedrockRegion: string;
  /** Role the handler is permitted to assume for row-level table access. */
  readonly tableAccessRole: iam.IRole;
  /** Table the handler gets read/write data access to (session storage). */
  readonly websocketSessionTable: ITable;
  /** Bucket the handler gets read/write access to (`LARGE_MESSAGE_BUCKET`). */
  readonly largeMessageBucket: s3.IBucket;
  /** Optional destination for the large-payload bucket's server access logs. */
  readonly accessLogBucket?: s3.Bucket;
}

/**
 * WebSocket API backed by a single container-image Lambda handler.
 *
 * The construct creates:
 *  - an S3 bucket for SNS large payload support,
 *  - an IAM role for the handler,
 *  - the Docker-image Lambda function itself,
 *  - a WebSocket API with `$connect` and `$default` routes that both integrate
 *    with that function, plus a stage and a route response for `$default`,
 *  - a stack output containing the stage endpoint.
 */
export class WebSocket extends Construct {
  /** The WebSocket API created by this construct. */
  readonly webSocketApi: apigwv2.IWebSocketApi;
  /** The Lambda function serving every route of the API. */
  readonly handler: IFunction;
  /** Name of the single stage deployed for the API. */
  private readonly defaultStageName = "dev";

  /**
   * @param scope Parent construct.
   * @param id Construct id, unique within `scope`.
   * @param props Resources and settings the handler depends on.
   */
  constructor(scope: Construct, id: string, props: WebSocketProps) {
    super(scope, id);

    const {
      auth,
      database,
      dbSecrets,
      largeMessageBucket,
      tableAccessRole,
      websocketSessionTable,
    } = props;
    const stack = Stack.of(this);

    // Bucket for SNS large payload support.
    // See: https://docs.aws.amazon.com/sns/latest/dg/extended-client-library-python.html
    const payloadBucket = new s3.Bucket(this, "LargePayloadSupportBucket", {
      encryption: s3.BucketEncryption.S3_MANAGED,
      blockPublicAccess: s3.BlockPublicAccess.BLOCK_ALL,
      enforceSSL: true,
      removalPolicy: RemovalPolicy.DESTROY,
      objectOwnership: s3.ObjectOwnership.OBJECT_WRITER,
      autoDeleteObjects: true,
      serverAccessLogsBucket: props.accessLogBucket,
      serverAccessLogsPrefix: "LargePayloadSupportBucket",
    });

    // ---- Execution role for the handler -----------------------------------
    const executionRole = new iam.Role(this, "HandlerRole", {
      assumedBy: new iam.ServicePrincipal("lambda.amazonaws.com"),
    });

    // Assume the table access role for row-level access control.
    const assumeTableAccessRole = new iam.PolicyStatement({
      actions: ["sts:AssumeRole"],
      resources: [tableAccessRole.roleArn],
    });
    executionRole.addToPolicy(assumeTableAccessRole);

    const bedrockAccess = new iam.PolicyStatement({
      actions: ["bedrock:*"],
      resources: ["*"],
    });
    executionRole.addToPolicy(bedrockAccess);

    const vpcExecutionPolicy = iam.ManagedPolicy.fromAwsManagedPolicyName(
      "service-role/AWSLambdaVPCAccessExecutionRole"
    );
    executionRole.addManagedPolicy(vpcExecutionPolicy);

    // Data-plane permissions on the storage resources the handler touches.
    payloadBucket.grantRead(executionRole);
    websocketSessionTable.grantReadWriteData(executionRole);
    largeMessageBucket.grantReadWrite(executionRole);

    // ---- Handler function --------------------------------------------------
    const backendDir = path.join(__dirname, "../../../backend");
    const imageCode = DockerImageCode.fromImageAsset(backendDir, {
      platform: Platform.LINUX_AMD64,
      file: "websocket.Dockerfile",
    });

    // Configuration handed to the container through environment variables.
    const handlerEnvironment: { [key: string]: string } = {
      ACCOUNT: stack.account,
      REGION: stack.region,
      USER_POOL_ID: auth.userPool.userPoolId,
      CLIENT_ID: auth.client.userPoolClientId,
      BEDROCK_REGION: props.bedrockRegion,
      TABLE_NAME: database.tableName,
      TABLE_ACCESS_ROLE_ARN: tableAccessRole.roleArn,
      LARGE_MESSAGE_BUCKET: largeMessageBucket.bucketName,
      DB_SECRETS_ARN: dbSecrets.secretArn,
      LARGE_PAYLOAD_SUPPORT_BUCKET: payloadBucket.bucketName,
      WEBSOCKET_SESSION_TABLE_NAME: websocketSessionTable.tableName,
    };

    const fn = new DockerImageFunction(this, "Handler", {
      code: imageCode,
      vpc: props.vpc,
      vpcSubnets: { subnetType: ec2.SubnetType.PRIVATE_WITH_EGRESS },
      memorySize: 512,
      timeout: Duration.minutes(15),
      environment: handlerEnvironment,
      role: executionRole,
    });
    dbSecrets.grantRead(fn);

    // ---- WebSocket API -----------------------------------------------------
    // Both the connect route and the catch-all route go to the same function.
    const connectIntegration = new WebSocketLambdaIntegration(
      "ConnectIntegration",
      fn
    );
    const api = new apigwv2.WebSocketApi(this, "WebSocketApi", {
      connectRouteOptions: { integration: connectIntegration },
    });

    const defaultIntegration = new WebSocketLambdaIntegration(
      "DefaultIntegration",
      fn
    );
    const defaultRoute = api.addRoute("$default", {
      integration: defaultIntegration,
    });

    new apigwv2.WebSocketStage(this, "WebSocketStage", {
      webSocketApi: api,
      stageName: this.defaultStageName,
      autoDeploy: true,
    });

    // Let the handler manage connections of this API.
    api.grantManageConnections(fn);

    // Attach a route response to the $default route.
    new CfnRouteResponse(this, "RouteResponse", {
      apiId: api.apiId,
      routeId: defaultRoute.routeId,
      routeResponseKey: "$default",
    });

    // The fields must be set before the output below, because the
    // `apiEndpoint` getter reads `this.webSocketApi`.
    this.webSocketApi = api;
    this.handler = fn;

    new CfnOutput(this, "WebSocketEndpoint", {
      value: this.apiEndpoint,
    });
  }

  /**
   * Endpoint of the API with the stage name appended
   * (`<api endpoint>/<stage name>`).
   */
  get apiEndpoint() {
    return [this.webSocketApi.apiEndpoint, this.defaultStageName].join("/");
  }
}