backend-apis/deployment_scripts/vertex_vector_operations.py (208 lines of code) (raw):

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Create a Vertex AI Vector Search index and an index endpoint, then deploy
the index to the endpoint and print the identifiers needed to query it.
"""

import argparse

from google.cloud import aiplatform_v1beta1 as aiplatform
from google.protobuf import struct_pb2
from google.protobuf.json_format import ParseDict

# pylint: disable-next = line-too-long
METADATA_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/matchingengine/metadata/nearest_neighbor_search_1.0.0.yaml"
# STREAM_UPDATE is value 2 of Index.IndexUpdateMethod.
INDEX_UPDATE_METHOD = aiplatform.Index.IndexUpdateMethod.STREAM_UPDATE

# Region used when no location can be derived from a resource name.
_FALLBACK_LOCATION = "us-central1"


def _client_options(location: str) -> dict:
    """
    Build client options targeting the regional Vertex AI API endpoint

    Args:
        location: Region of the resources being managed, e.g. "us-central1"

    Returns:
        Client options dict for the Vertex AI service clients
    """
    # The API endpoint must match the region in the request's parent/name;
    # a fixed us-central1 endpoint breaks every other --location.
    return {"api_endpoint": f"{location}-aiplatform.googleapis.com"}


def _location_from_resource_name(resource_name: str) -> str:
    """
    Extract the location segment from a Vertex AI resource name

    Args:
        resource_name: Name shaped like
            "projects/<project>/locations/<location>/indexEndpoints/<id>"

    Returns:
        The "<location>" segment, or _FALLBACK_LOCATION when the name has
        no non-empty "locations/<location>" segment
    """
    segments = resource_name.split("/")
    try:
        location = segments[segments.index("locations") + 1]
    except (ValueError, IndexError):
        return _FALLBACK_LOCATION
    return location or _FALLBACK_LOCATION


def _first_deployed_index_id(index_endpoint: aiplatform.IndexEndpoint) -> str:
    """
    Return the id of the first index deployed to an endpoint

    Args:
        index_endpoint: Index endpoint resource

    Returns:
        Deployed index id, or "" when the endpoint has no deployed index
    """
    # A freshly created endpoint has no deployed indexes, so [0] must be
    # guarded rather than indexed unconditionally.
    if not index_endpoint.deployed_indexes:
        return ""
    return index_endpoint.deployed_indexes[0].id


def create_vector_index(
    project_id: str,
    location: str,
    display_name: str,
    description: str,
    metadata: struct_pb2.Value,
) -> aiplatform.Index:
    """
    Create vector index and block until the operation finishes

    Args:
        project_id: Project id
        location: Index location (also selects the regional API endpoint)
        display_name: Display name
        description: Index description
        metadata: Index Metadata

    Returns:
        Creation response (the created Index, including its resource name)
    """
    index_client = aiplatform.IndexServiceClient(
        client_options=_client_options(location)
    )
    index = aiplatform.Index()
    index.display_name = display_name
    index.description = description
    index.metadata_schema_uri = METADATA_SCHEMA_URI
    index.metadata = metadata
    index.index_update_method = INDEX_UPDATE_METHOD
    request = aiplatform.CreateIndexRequest(
        parent=f"projects/{project_id}/locations/{location}", index=index
    )
    operation = index_client.create_index(request=request)
    # timeout=None waits for the long-running operation without a deadline.
    return operation.result(timeout=None)


def create_index_endpoint(
    project_id: str,
    location: str,
    display_name: str,
    description: str,
    public_endpoint_enabled: bool,
) -> aiplatform.IndexEndpoint:
    """
    Create index endpoint and block until the operation finishes

    Args:
        project_id: Project id
        location: Index endpoint location (also selects the regional API
            endpoint)
        display_name: Index endpoint display name
        description: Index endpoint description
        public_endpoint_enabled: Whether to enable public endpoint

    Returns:
        Creation response (the created IndexEndpoint, including its
        resource name)
    """
    index_endpoint_client = aiplatform.IndexEndpointServiceClient(
        client_options=_client_options(location)
    )
    index_endpoint = aiplatform.IndexEndpoint()
    index_endpoint.display_name = display_name
    index_endpoint.description = description
    index_endpoint.public_endpoint_enabled = public_endpoint_enabled
    request = aiplatform.CreateIndexEndpointRequest(
        parent=f"projects/{project_id}/locations/{location}",
        index_endpoint=index_endpoint,
    )
    operation = index_endpoint_client.create_index_endpoint(request=request)
    return operation.result(timeout=None)


def deploy_index_to_endpoint(
    deploy_id: str,
    index: str,
    display_name: str,
    index_endpoint: str,
):
    """
    Deploy index to endpoint and block until the operation finishes

    Args:
        deploy_id: Deploy id (becomes the deployed index id)
        index: Index resource name
        display_name: Display name
        index_endpoint: Index endpoint resource name; its
            "locations/<location>" segment selects the regional API endpoint

    Returns:
        Deployment response
    """
    index_endpoint_client = aiplatform.IndexEndpointServiceClient(
        client_options=_client_options(
            _location_from_resource_name(index_endpoint)
        )
    )
    deploy = aiplatform.DeployedIndex()
    deploy.id = deploy_id
    deploy.index = index
    deploy.display_name = display_name
    request = aiplatform.DeployIndexRequest(
        index_endpoint=index_endpoint, deployed_index=deploy
    )
    operation = index_endpoint_client.deploy_index(request=request)
    return operation.result(timeout=None)


def _get_index_endpoint(index_endpoint_name: str) -> aiplatform.IndexEndpoint:
    """
    Fetch an index endpoint by its full resource name

    Args:
        index_endpoint_name: Index endpoint resource name

    Returns:
        The IndexEndpoint resource as currently stored by the service
    """
    client = aiplatform.IndexEndpointServiceClient(
        client_options=_client_options(
            _location_from_resource_name(index_endpoint_name)
        )
    )
    request = aiplatform.GetIndexEndpointRequest(name=index_endpoint_name)
    return client.get_index_endpoint(request=request)


def get_index_resource_name(
    project_id: str, location: str, index_display_name: str
) -> str:
    """
    Get index resource name

    Args:
        project_id: Project id
        location: Index location
        index_display_name: Index display name

    Returns:
        Resource name of the first index with a matching display name,
        or "" if none matches
    """
    client = aiplatform.IndexServiceClient(
        client_options=_client_options(location)
    )
    request = aiplatform.ListIndexesRequest(
        parent=f"projects/{project_id}/locations/{location}"
    )
    # Iterating the pager directly fetches further pages lazily.
    return next(
        (
            index.name
            for index in client.list_indexes(request=request)
            if index.display_name == index_display_name
        ),
        "",
    )


def get_endpoint_info(
    project_id: str, location: str, endpoint_display_name: str
) -> tuple:
    """
    Get endpoint info

    Args:
        project_id: Project id
        location: Endpoint location
        endpoint_display_name: Endpoint display name

    Returns:
        Tuple with Endpoint name, Deployed index id and Public Endpoint
        domain name for the first endpoint with a matching display name.
        Each element is "" when no endpoint matches; the deployed index id
        is "" when the endpoint has nothing deployed.
    """
    client = aiplatform.IndexEndpointServiceClient(
        client_options=_client_options(location)
    )
    request = aiplatform.ListIndexEndpointsRequest(
        parent=f"projects/{project_id}/locations/{location}",
    )
    for endpoint in client.list_index_endpoints(request=request):
        if endpoint.display_name == endpoint_display_name:
            return (
                endpoint.name,
                _first_deployed_index_id(endpoint),
                endpoint.public_endpoint_domain_name,
            )
    return "", "", ""


def main(args):
    """
    Main creation function: create the index and endpoint, deploy the
    index, then print the endpoint id, deployed index id and public domain

    Args:
        args: Command line args
    """
    metadata = {
        "contentsDeltaUri": args.contents_delta_uri,
        "config": {
            "dimensions": 1408,
            "approximateNeighborsCount": 150,
            "distanceMeasureType": "DOT_PRODUCT_DISTANCE",
            "featureNormType": "UNIT_L2_NORM",
            "algorithmConfig": {
                "treeAhConfig": {
                    "leafNodeEmbeddingCount": 1000,
                    "fractionLeafNodesToSearch": 0.05,
                }
            },
        },
    }
    struct = struct_pb2.Struct()
    ParseDict(metadata, struct)
    schema_value = struct_pb2.Value(struct_value=struct)

    print("Creating vector index")
    index = create_vector_index(
        project_id=args.project_id,
        location=args.location,
        display_name=args.index_display_name,
        description=args.index_description,
        metadata=schema_value,
    )

    print("Creating index endpoint")
    index_endpoint = create_index_endpoint(
        project_id=args.project_id,
        location=args.location,
        display_name=args.endpoint_display_name,
        description=args.endpoint_description,
        public_endpoint_enabled=True,
    )

    print("Deploying index to endpoint")
    # Use the resource names returned by the create calls instead of a
    # display-name lookup, which could match resources from earlier runs.
    deploy_index_to_endpoint(
        deploy_id=args.deploy_id,
        index=index.name,
        display_name=args.deploy_display_name,
        index_endpoint=index_endpoint.name,
    )

    # Read deployment details only after deploying: the endpoint was just
    # created, so before this point it has no deployed indexes.
    deployed_endpoint = _get_index_endpoint(index_endpoint.name)

    print("Index endpoint id:")
    print(deployed_endpoint.name.split(sep="/")[-1])
    print("Deployed index id:")
    print(_first_deployed_index_id(deployed_endpoint))
    print("Vector API endpoint")
    print(deployed_endpoint.public_endpoint_domain_name)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--project_id", required=True)
    parser.add_argument("--location", required=True)
    parser.add_argument(
        "--index_display_name",
        default="csm-multimodal-vector-search",
        required=False,
    )
    parser.add_argument(
        "--index_description",
        default="CSM Multimodal Vector Search",
        required=False,
    )
    parser.add_argument(
        "--endpoint_display_name", default="csm-index-endpoint", required=False
    )
    parser.add_argument(
        "--endpoint_description", default="CSM Index Endpoint", required=False
    )
    parser.add_argument(
        "--deploy_display_name", default="csm_deployed_index", required=False
    )
    parser.add_argument(
        "--deploy_id", default="csm_deployed_index", required=False
    )
    parser.add_argument(
        "--contents_delta_uri",
        default="gs://csm-solution-dataset/metadata/vertex-vector-search",
        required=False,
    )
    parsed_args = parser.parse_args()
    main(parsed_args)