custom-targets/vertex-ai/model-deployer/addaliases.go (63 lines of code) (raw):

// Copyright 2023 Google LLC // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // https://www.apache.org/licenses/LICENSE-2.0 // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package main import ( "context" "fmt" "cloud.google.com/go/storage" "github.com/GoogleCloudPlatform/cloud-deploy-samples/custom-targets/util/clouddeploy" "google.golang.org/api/aiplatform/v1" cdapi "google.golang.org/api/clouddeploy/v1" ) // aliasAssigner is responsible for applying model aliases during a post-deploy operation. type aliasAssigner struct { gcsClient *storage.Client request *addAliasesRequest } // process applies model aliases during a post-deploy operation. func (aa aliasAssigner) process(ctx context.Context) error { cdService, err := cdapi.NewService(ctx) if err != nil { return fmt.Errorf("unable to create cloud deploy API service: %v", err) } releaseName := fmt.Sprintf("projects/%s/locations/%s/deliveryPipelines/%s/releases/%s", aa.request.project, aa.request.location, aa.request.pipeline, aa.request.release) release, err := cdService.Projects.Locations.DeliveryPipelines.Releases.Get(releaseName).Do() if err != nil { return fmt.Errorf("unable to fetch release to determine location of rendered manifest: %v", err) } ta, ok := release.TargetArtifacts[aa.request.target] if !ok { return fmt.Errorf("target artifact does not exist in release") } pa, ok := ta.PhaseArtifacts[aa.request.phase] if !ok { return fmt.Errorf("target phase artifact not found in release") } manifestGcsPath := fmt.Sprintf("%s/%s", ta.ArtifactUri, pa.ManifestPath) localManifest := "manifest.yaml" fmt.Printf("Downloading deploy input manifest from %q.\n", manifestGcsPath) deployRequest := &clouddeploy.DeployRequest{ ManifestGCSPath: manifestGcsPath, } fmt.Printf("Downloading rendered manifest.\n") if _, err := deployRequest.DownloadManifest(ctx, aa.gcsClient, localManifest); err != nil { fmt.Println("Failed to download rendered manifest.") return fmt.Errorf("failed to download local manifest: %v", err) } deployedModelRequest, err := deployModelFromManifest(localManifest) if err != nil { return err } modelName := deployedModelRequest.DeployedModel.Model modelRegion, err := regionFromModel(modelName) if err != nil { return fmt.Errorf("unable to obtain region where deployed model is located: %v", err) } aiPlatformService, err := newAIPlatformService(ctx, modelRegion) if err != nil { return fmt.Errorf("unable to create aiplatform service: %v", err) } mergeVersionAliasRequest := &aiplatform.GoogleCloudAiplatformV1MergeVersionAliasesRequest{VersionAliases: aa.request.aliases} updatedModel, err := aiPlatformService.Projects.Locations.Models.MergeVersionAliases(modelName, mergeVersionAliasRequest).Do() if err != nil { return fmt.Errorf("unable to update model version aliases") } fmt.Printf("Successfully applied new aliases: %s. Current aliases are: %s\n", aa.request.aliases, updatedModel.VersionAliases) return nil }