accessors/dataflow/dataflow_accessor.go

// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//	http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package dataflowaccessor

import (
	"context"
	"fmt"
	"sort"
	"strings"

	"cloud.google.com/go/dataflow/apiv1beta3/dataflowpb"
	dataflowclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/dataflow"
	"github.com/GoogleCloudPlatform/spanner-migration-tool/logger"
	"golang.org/x/exp/maps"
)

// DataflowAccessor provides methods that internally use the Dataflow client.
// Methods here should contain only generic logic that can be shared across
// multiple workflows.
type DataflowAccessor interface {
	// LaunchDataflowTemplate takes the template parameters (@parameters) and the
	// runtime environment config (@cfg) as input, and returns the generated job ID,
	// the equivalent gcloud command, and an error, if any.
	LaunchDataflowTemplate(ctx context.Context, c dataflowclient.DataflowClient, parameters map[string]string, cfg DataflowTuningConfig) (string, string, error)
}

// DataflowAccessorImpl implements the DataflowAccessor interface. This is the
// primary implementation and should be used everywhere other than tests.
type DataflowAccessorImpl struct{}

func (dfA *DataflowAccessorImpl) LaunchDataflowTemplate(ctx context.Context, c dataflowclient.DataflowClient, parameters map[string]string, cfg DataflowTuningConfig) (string, string, error) {
	req, err := getDataflowLaunchRequest(parameters, cfg)
	if err != nil {
		return "", "", err
	}
	respDf, err := c.LaunchFlexTemplate(ctx, req)
	if err != nil {
		logger.Log.Error(fmt.Sprintf("flexTemplateRequest: %+v\n", req))
		return "", "", fmt.Errorf("error launching dataflow template: %v", err)
	}
	gCloudCmd := GetGcloudDataflowCommandFromRequest(req)
	return respDf.Job.Id, gCloudCmd, nil
}
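// A minimal usage sketch for reference; the tuning-config values and the
// template parameter names below are hypothetical, and `client` is assumed to
// be a dataflowclient.DataflowClient obtained elsewhere:
//
//	accessor := &DataflowAccessorImpl{}
//	jobId, gCloudCmd, err := accessor.LaunchDataflowTemplate(ctx, client,
//		map[string]string{"instanceId": "my-instance", "databaseId": "my-db"}, // hypothetical parameters
//		DataflowTuningConfig{
//			ProjectId:       "my-project",                // hypothetical
//			Location:        "us-central1",               // hypothetical
//			JobName:         "smt-job",                   // lower-cased automatically before launch
//			GcsTemplatePath: "gs://bucket/template.json", // hypothetical
//			NumWorkers:      5,
//		})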
func getDataflowLaunchRequest(parameters map[string]string, cfg DataflowTuningConfig) (*dataflowpb.LaunchFlexTemplateRequest, error) {
	// If a custom network is not selected, use a public IP. Typical for the internal testing flow.
	vpcSubnetwork := ""
	workerIpAddressConfig := dataflowpb.WorkerIPAddressConfiguration_WORKER_IP_PUBLIC
	if cfg.Network != "" || cfg.Subnetwork != "" {
		workerIpAddressConfig = dataflowpb.WorkerIPAddressConfiguration_WORKER_IP_PRIVATE
		// If a subnetwork is not provided, assume the network has auto-subnet configuration.
		if cfg.Subnetwork != "" {
			if cfg.VpcHostProjectId == "" || cfg.Location == "" {
				return nil, fmt.Errorf("vpc host project id and location must be specified when specifying subnetwork")
			}
			vpcSubnetwork = fmt.Sprintf("https://www.googleapis.com/compute/v1/projects/%s/regions/%s/subnetworks/%s", cfg.VpcHostProjectId, cfg.Location, cfg.Subnetwork)
		}
	}
	// Dataflow does not accept upper-case letters in the job name.
	cfg.JobName = strings.ToLower(cfg.JobName)
	request := &dataflowpb.LaunchFlexTemplateRequest{
		ProjectId: cfg.ProjectId,
		LaunchParameter: &dataflowpb.LaunchFlexTemplateParameter{
			JobName:    cfg.JobName,
			Template:   &dataflowpb.LaunchFlexTemplateParameter_ContainerSpecGcsPath{ContainerSpecGcsPath: cfg.GcsTemplatePath},
			Parameters: parameters,
			Environment: &dataflowpb.FlexTemplateRuntimeEnvironment{
				MaxWorkers:            cfg.MaxWorkers,
				NumWorkers:            cfg.NumWorkers,
				ServiceAccountEmail:   cfg.ServiceAccountEmail,
				MachineType:           cfg.MachineType,
				AdditionalUserLabels:  cfg.AdditionalUserLabels,
				KmsKeyName:            cfg.KmsKeyName,
				Network:               cfg.Network,
				Subnetwork:            vpcSubnetwork,
				IpConfiguration:       workerIpAddressConfig,
				AdditionalExperiments: cfg.AdditionalExperiments,
				EnableStreamingEngine: cfg.EnableStreamingEngine,
			},
		},
		Location: cfg.Location,
	}
	logger.Log.Debug(fmt.Sprintf("Flex Template request generated: %+v", request))
	return request, nil
}

// GetGcloudDataflowCommandFromRequest generates the equivalent gcloud CLI command
// to launch a Dataflow job with the same parameters and environment flags as the
// input request.
func GetGcloudDataflowCommandFromRequest(req *dataflowpb.LaunchFlexTemplateRequest) string {
	lp := req.LaunchParameter
	// This type assertion is safe for requests built by getDataflowLaunchRequest,
	// which always sets the template via ContainerSpecGcsPath.
	templatePath := lp.Template.(*dataflowpb.LaunchFlexTemplateParameter_ContainerSpecGcsPath).ContainerSpecGcsPath
	cmd := fmt.Sprintf("gcloud dataflow flex-template run %s --project=%s --region=%s --template-file-gcs-location=%s %s %s",
		lp.JobName, req.ProjectId, req.Location, templatePath, getEnvironmentFlags(lp.Environment), getParametersFlag(lp.Parameters))
	return strings.Trim(cmd, " ")
}

// getParametersFlag generates the equivalent --parameters flag string, returning
// the empty string if no parameters are specified. Keys are sorted so the output
// is deterministic.
func getParametersFlag(parameters map[string]string) string {
	if len(parameters) == 0 {
		return ""
	}
	keys := maps.Keys(parameters)
	sort.Strings(keys)
	pairs := make([]string, 0, len(keys))
	for _, k := range keys {
		pairs = append(pairs, k+"="+parameters[k])
	}
	return fmt.Sprintf("--parameters %s", strings.Join(pairs, ","))
}
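// For illustration, a request built from the hypothetical config above (JobName
// "smt-job", ProjectId "my-project", Location "us-central1", template path
// "gs://bucket/template.json", NumWorkers 5, and the two parameters) would yield:
//
//	gcloud dataflow flex-template run smt-job --project=my-project --region=us-central1 --template-file-gcs-location=gs://bucket/template.json --num-workers 5 --parameters databaseId=my-db,instanceId=my-instance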
// getEnvironmentFlags generates the environment flag string. We don't populate
// all flags available in the API because certain flags (like AutoscalingAlgorithm,
// DumpHeapOnOom, etc.) are not supported in gcloud.
func getEnvironmentFlags(environment *dataflowpb.FlexTemplateRuntimeEnvironment) string {
	flag := ""
	if environment.NumWorkers != 0 {
		flag += fmt.Sprintf("--num-workers %d ", environment.NumWorkers)
	}
	if environment.MaxWorkers != 0 {
		flag += fmt.Sprintf("--max-workers %d ", environment.MaxWorkers)
	}
	if environment.ServiceAccountEmail != "" {
		flag += fmt.Sprintf("--service-account-email %s ", environment.ServiceAccountEmail)
	}
	if environment.TempLocation != "" {
		flag += fmt.Sprintf("--temp-location %s ", environment.TempLocation)
	}
	if environment.MachineType != "" {
		flag += fmt.Sprintf("--worker-machine-type %s ", environment.MachineType)
	}
	if len(environment.AdditionalExperiments) > 0 {
		flag += fmt.Sprintf("--additional-experiments %s ", strings.Join(environment.AdditionalExperiments, ","))
	}
	if environment.Network != "" {
		flag += fmt.Sprintf("--network %s ", environment.Network)
	}
	if environment.Subnetwork != "" {
		flag += fmt.Sprintf("--subnetwork %s ", environment.Subnetwork)
	}
	if len(environment.AdditionalUserLabels) > 0 {
		flag += fmt.Sprintf("--additional-user-labels %s ", formatAdditionalUserLabels(environment.AdditionalUserLabels))
	}
	if environment.KmsKeyName != "" {
		flag += fmt.Sprintf("--dataflow-kms-key %s ", environment.KmsKeyName)
	}
	if environment.IpConfiguration == dataflowpb.WorkerIPAddressConfiguration_WORKER_IP_PRIVATE {
		flag += "--disable-public-ips "
	}
	if environment.WorkerRegion != "" {
		flag += fmt.Sprintf("--worker-region %s ", environment.WorkerRegion)
	}
	if environment.WorkerZone != "" {
		flag += fmt.Sprintf("--worker-zone %s ", environment.WorkerZone)
	}
	if environment.EnableStreamingEngine {
		flag += "--enable-streaming-engine "
	}
	if environment.FlexrsGoal != dataflowpb.FlexResourceSchedulingGoal_FLEXRS_UNSPECIFIED {
		flag += fmt.Sprintf("--flexrs-goal %s ", environment.FlexrsGoal)
	}
	if environment.StagingLocation != "" {
		flag += fmt.Sprintf("--staging-location %s ", environment.StagingLocation)
	}
	return strings.Trim(flag, " ")
}

// formatAdditionalUserLabels renders the labels map as comma-separated key=value
// pairs. Keys are sorted so the generated command is deterministic (matching
// getParametersFlag); Go map iteration order is otherwise randomized.
func formatAdditionalUserLabels(labels map[string]string) string {
	keys := maps.Keys(labels)
	sort.Strings(keys)
	res := make([]string, 0, len(keys))
	for _, key := range keys {
		res = append(res, fmt.Sprintf("%s=%s", key, labels[key]))
	}
	return strings.Join(res, ",")
}
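// A sketch of the flag strings these helpers produce, using assumed values:
//
//	getEnvironmentFlags(&dataflowpb.FlexTemplateRuntimeEnvironment{
//		NumWorkers:            5,
//		MaxWorkers:            50,
//		MachineType:           "n1-standard-2", // hypothetical machine type
//		EnableStreamingEngine: true,
//	})
//	// => "--num-workers 5 --max-workers 50 --worker-machine-type n1-standard-2 --enable-streaming-engine"
//
//	formatAdditionalUserLabels(map[string]string{"env": "dev", "team": "migrations"}) // hypothetical labels
//	// => "env=dev,team=migrations"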