custom-targets/vertex-ai/model-deployer/vertexai.go
// Copyright 2023 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// https://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"context"
"fmt"
"os"
"strings"

"google.golang.org/api/aiplatform/v1"
"google.golang.org/api/option"
"sigs.k8s.io/yaml"
)
// deployModelFromManifest loads the file provided in `path` and returns the parsed DeployModelRequest
// from the data.
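//
// An illustrative manifest (a minimal sketch; the field names follow the Vertex AI v1
// DeployModelRequest JSON schema, the resource name is a placeholder, and in
// trafficSplit the key "0" refers to the model being deployed by this request):
//
//	deployedModel:
//	  model: projects/my-project/locations/us-central1/models/my-model
//	  dedicatedResources:
//	    machineSpec:
//	      machineType: n1-standard-4
//	    minReplicaCount: 1
//	trafficSplit:
//	  "0": 100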
func deployModelFromManifest(path string) (*aiplatform.GoogleCloudAiplatformV1DeployModelRequest, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("error reading manifest file: %v", err)
}
deployModelRequest := &aiplatform.GoogleCloudAiplatformV1DeployModelRequest{}
if err = yaml.Unmarshal(data, deployModelRequest); err != nil {
return nil, fmt.Errorf("unable to parse deploy model deployModelRequest from manifest file: %v", err)
}
return deployModelRequest, nil
}
// fetchPreviousModel queries the provided Vertex AI endpoint to determine the model that was previously
// deployed.
func fetchPreviousModel(service *aiplatform.Service, endpointName, currentModel string) (string, error) {
endpoint, err := service.Projects.Locations.Endpoints.Get(endpointName).Do()
if err != nil {
return "", fmt.Errorf("unable to fetch endpoint: %v", err)
}
deployedModels := map[string]*aiplatform.GoogleCloudAiplatformV1DeployedModel{}
for _, dm := range endpoint.DeployedModels {
modelNameWithVersion := resolveDeployedModelNameWithVersion(dm)
deployedModels[modelNameWithVersion] = dm
}
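// Drop the model that is currently being deployed; the single remaining entry is
// treated as the previously deployed model to canary against.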
delete(deployedModels, currentModel)
if len(deployedModels) != 1 {
return "", fmt.Errorf("unable to resolve previous deployed currentModel to canary against. Not including the current currentModel to be deployed, the endpoint has %d deployed models but expected only one", len(deployedModels))
}
var previousModelID string
for _, dm := range deployedModels {
previousModelID = dm.Id
}
return previousModelID, nil
}
// resolveDeployedModelNameWithVersion returns the model resource name associated with the provided DeployedModel
// with its version ID attached.
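// For example, a deployed model referencing "projects/p/locations/us-central1/models/m"
// with version ID "3" resolves to "projects/p/locations/us-central1/models/m@3".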
func resolveDeployedModelNameWithVersion(deployedModel *aiplatform.GoogleCloudAiplatformV1DeployedModel) string {
if strings.Contains(deployedModel.Model, "@") {
return deployedModel.Model
}
return fmt.Sprintf("%s@%s", deployedModel.Model, deployedModel.ModelVersionId)
}
// resolveModelWithVersion returns the model resource name with its version ID attached.
func resolveModelWithVersion(model *aiplatform.GoogleCloudAiplatformV1Model) string {
if strings.Contains(model.Name, "@") {
return model.Name
}
return fmt.Sprintf("%s@%s", model.Name, model.VersionId)
}
// regionFromModel extracts the region from the model resource name.
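// The name is expected to have the form
// "projects/{project}/locations/{region}/models/{model}"; modelRegex (defined
// elsewhere in this package) is assumed to capture the region as its second submatch.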
func regionFromModel(modelName string) (string, error) {
matches := modelRegex.FindStringSubmatch(modelName)
if len(matches) == 0 {
return "", fmt.Errorf("unable to parse model name")
}
return matches[2], nil
}
// regionFromEndpoint extracts the region from the endpoint resource name.
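// The name is expected to have the form
// "projects/{project}/locations/{region}/endpoints/{endpoint}", with the region
// captured as endpointRegex's second submatch.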
func regionFromEndpoint(endpointName string) (string, error) {
matches := endpointRegex.FindStringSubmatch(endpointName)
if len(matches) == 0 {
return "", fmt.Errorf("unable to parse endpoint name")
}
return matches[2], nil
}
// newAIPlatformService generates a Service that can make API calls in the specified region.
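// For example, region "us-central1" yields the regional API endpoint
// "us-central1-aiplatform.googleapis.com".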
func newAIPlatformService(ctx context.Context, region string) (*aiplatform.Service, error) {
endPointOption := option.WithEndpoint(fmt.Sprintf("%s-aiplatform.googleapis.com", region))
regionalService, err := aiplatform.NewService(ctx, endPointOption)
if err != nil {
return nil, fmt.Errorf("unable to authenticate")
}
return regionalService, nil
}
// fetchModel calls the aiplatform API to fetch the Vertex AI model using the given model name.
func fetchModel(service *aiplatform.Service, modelName string) (*aiplatform.GoogleCloudAiplatformV1Model, error) {
model, err := service.Projects.Locations.Models.Get(modelName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get model: %v", err)
}
return model, nil
}
// minReplicaCountFromConfig returns the minReplicaCount value from the provided deployed model's
// dedicated resources, or 0 if no dedicated resources are configured.
func minReplicaCountFromConfig(deployedModel *aiplatform.GoogleCloudAiplatformV1DeployedModel) int64 {
if deployedModel.DedicatedResources != nil {
return deployedModel.DedicatedResources.MinReplicaCount
}
return 0
}
// deployModel performs the DeployModel request and awaits the resulting operation until it completes, times out, or an error occurs.
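//
// An illustrative call sequence (a sketch; "manifest.yaml" and endpointName are
// placeholders, and error handling is omitted for brevity):
//
//	req, _ := deployModelFromManifest("manifest.yaml")
//	region, _ := regionFromEndpoint(endpointName)
//	service, _ := newAIPlatformService(ctx, region)
//	err := deployModel(ctx, service, endpointName, req)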
func deployModel(ctx context.Context, aiPlatformService *aiplatform.Service, endpoint string, request *aiplatform.GoogleCloudAiplatformV1DeployModelRequest) error {
op, err := aiPlatformService.Projects.Locations.Endpoints.DeployModel(endpoint, request).Do()
if err != nil {
return fmt.Errorf("unable to deploy model: %v", err)
}
return poll(ctx, aiPlatformService, op)
}
// undeployNoTrafficModels fetches the Vertex AI endpoint and undeploys all the models that have no traffic routed to them.
func undeployNoTrafficModels(ctx context.Context, aiPlatformService *aiplatform.Service, endpointName string) error {
endpoint, err := aiPlatformService.Projects.Locations.Endpoints.Get(endpointName).Do()
if err != nil {
return fmt.Errorf("unable to fetch endpoint where model was deployed: %v", err)
}
modelsToUndeploy := map[string]bool{}
for _, dm := range endpoint.DeployedModels {
modelsToUndeploy[dm.Id] = true
}
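// endpoint.TrafficSplit maps each deployed model's ID to the percentage of traffic it
// receives; any model with a non-zero share stays deployed.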
for id, split := range endpoint.TrafficSplit {
// a model is not undeployed if it is configured to receive traffic
if split != 0 {
delete(modelsToUndeploy, id)
}
}
var lros []*aiplatform.GoogleLongrunningOperation
for id := range modelsToUndeploy {
undeployRequest := &aiplatform.GoogleCloudAiplatformV1UndeployModelRequest{DeployedModelId: id}
lro, lroErr := aiPlatformService.Projects.Locations.Endpoints.UndeployModel(endpointName, undeployRequest).Do()
if lroErr != nil {
fmt.Printf("error undeploying model: %v\n", lroErr)
err = lroErr
} else {
lros = append(lros, lro)
}
}
for pollErr := range pollChan(ctx, aiPlatformService, lros...) {
if pollErr != nil {
fmt.Printf("error in undeploy model operation: %v\n", pollErr)
err = pollErr
}
}
return err
}