pkg/deploy/lattice/target_group_manager.go (428 lines of code) (raw):

package lattice import ( "context" "errors" "fmt" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/vpclattice" "github.com/aws/aws-application-networking-k8s/pkg/aws/services" "github.com/aws/aws-application-networking-k8s/pkg/model/core" "github.com/aws/aws-application-networking-k8s/pkg/utils" "reflect" pkg_aws "github.com/aws/aws-application-networking-k8s/pkg/aws" model "github.com/aws/aws-application-networking-k8s/pkg/model/lattice" "github.com/aws/aws-application-networking-k8s/pkg/utils/gwlog" ) //go:generate mockgen -destination target_group_manager_mock.go -package lattice github.com/aws/aws-application-networking-k8s/pkg/deploy/lattice TargetGroupManager type TargetGroupManager interface { Upsert(ctx context.Context, modelTg *model.TargetGroup) (model.TargetGroupStatus, error) Delete(ctx context.Context, modelTg *model.TargetGroup) error List(ctx context.Context) ([]tgListOutput, error) IsTargetGroupMatch(ctx context.Context, modelTg *model.TargetGroup, latticeTg *vpclattice.TargetGroupSummary, latticeTags *model.TargetGroupTagFields) (bool, error) ResolveRuleTgIds(ctx context.Context, modelRuleAction *model.RuleAction, stack core.Stack) error } type defaultTargetGroupManager struct { log gwlog.Logger cloud pkg_aws.Cloud } func NewTargetGroupManager(log gwlog.Logger, cloud pkg_aws.Cloud) *defaultTargetGroupManager { return &defaultTargetGroupManager{ log: log, cloud: cloud, } } func (s *defaultTargetGroupManager) Upsert( ctx context.Context, modelTg *model.TargetGroup, ) (model.TargetGroupStatus, error) { // check if exists latticeTgSummary, err := s.findTargetGroup(ctx, modelTg) if err != nil { return model.TargetGroupStatus{}, err } if latticeTgSummary == nil { return s.create(ctx, modelTg) } else { return s.update(ctx, modelTg, latticeTgSummary) } } func (s *defaultTargetGroupManager) create(ctx context.Context, modelTg *model.TargetGroup) (model.TargetGroupStatus, error) { var ipAddressType, protocolVersion *string if modelTg.Spec.IpAddressType != "" { ipAddressType = &modelTg.Spec.IpAddressType } if modelTg.Spec.ProtocolVersion == "" { protocolVersion = nil } else { protocolVersion = &modelTg.Spec.ProtocolVersion } latticeTgCfg := &vpclattice.TargetGroupConfig{ Port: aws.Int64(int64(modelTg.Spec.Port)), Protocol: &modelTg.Spec.Protocol, ProtocolVersion: protocolVersion, VpcIdentifier: &modelTg.Spec.VpcId, IpAddressType: ipAddressType, HealthCheck: modelTg.Spec.HealthCheckConfig, } latticeTgType := string(modelTg.Spec.Type) latticeTgName := model.GenerateTgName(modelTg.Spec) createInput := vpclattice.CreateTargetGroupInput{ Config: latticeTgCfg, Name: &latticeTgName, Type: &latticeTgType, Tags: s.cloud.DefaultTags(), } createInput.Tags[model.K8SClusterNameKey] = &modelTg.Spec.K8SClusterName createInput.Tags[model.K8SServiceNameKey] = &modelTg.Spec.K8SServiceName createInput.Tags[model.K8SServiceNamespaceKey] = &modelTg.Spec.K8SServiceNamespace createInput.Tags[model.K8SSourceTypeKey] = aws.String(string(modelTg.Spec.K8SSourceType)) createInput.Tags[model.K8SProtocolVersionKey] = &modelTg.Spec.ProtocolVersion if modelTg.Spec.IsSourceTypeRoute() { createInput.Tags[model.K8SRouteNameKey] = &modelTg.Spec.K8SRouteName createInput.Tags[model.K8SRouteNamespaceKey] = &modelTg.Spec.K8SRouteNamespace } lattice := s.cloud.Lattice() resp, err := lattice.CreateTargetGroupWithContext(ctx, &createInput) if err != nil { return model.TargetGroupStatus{}, fmt.Errorf("failed CreateTargetGroup %s due to %s", latticeTgName, err) } s.log.Infof(ctx, "Success CreateTargetGroup %s", latticeTgName) latticeTgStatus := aws.StringValue(resp.Status) if latticeTgStatus != vpclattice.TargetGroupStatusActive && latticeTgStatus != vpclattice.TargetGroupStatusCreateInProgress { s.log.Infof(ctx, "Target group is not in the desired state. State is %s, will retry", latticeTgStatus) return model.TargetGroupStatus{}, errors.New(LATTICE_RETRY) } // create-in-progress is considered success // later, target reg may need to retry due to the state, and that's OK return model.TargetGroupStatus{ Name: aws.StringValue(resp.Name), Arn: aws.StringValue(resp.Arn), Id: aws.StringValue(resp.Id)}, nil } func (s *defaultTargetGroupManager) update(ctx context.Context, targetGroup *model.TargetGroup, latticeTg *vpclattice.GetTargetGroupOutput) (model.TargetGroupStatus, error) { healthCheckConfig := targetGroup.Spec.HealthCheckConfig if healthCheckConfig == nil { s.log.Debugf(ctx, "HealthCheck is empty. Resetting to default settings") healthCheckConfig = &vpclattice.HealthCheckConfig{} } s.fillDefaultHealthCheckConfig(healthCheckConfig, targetGroup.Spec.Protocol, targetGroup.Spec.ProtocolVersion) if !reflect.DeepEqual(healthCheckConfig, latticeTg.Config.HealthCheck) { _, err := s.cloud.Lattice().UpdateTargetGroupWithContext(ctx, &vpclattice.UpdateTargetGroupInput{ HealthCheck: healthCheckConfig, TargetGroupIdentifier: latticeTg.Id, }) if err != nil { return model.TargetGroupStatus{}, fmt.Errorf("failed UpdateTargetGroup %s due to %w", aws.StringValue(latticeTg.Id), err) } } modelTgStatus := model.TargetGroupStatus{ Name: aws.StringValue(latticeTg.Name), Arn: aws.StringValue(latticeTg.Arn), Id: aws.StringValue(latticeTg.Id), } return modelTgStatus, nil } func (s *defaultTargetGroupManager) Delete(ctx context.Context, modelTg *model.TargetGroup) error { if modelTg.Status == nil || modelTg.Status.Id == "" { latticeTgSummary, err := s.findTargetGroup(ctx, modelTg) if err != nil { return err } if latticeTgSummary == nil { // nothing to delete s.log.Infof(ctx, "Target group with name prefix %s does not exist, nothing to delete", model.TgNamePrefix(modelTg.Spec)) return nil } modelTg.Status = &model.TargetGroupStatus{ Name: aws.StringValue(latticeTgSummary.Name), Arn: aws.StringValue(latticeTgSummary.Arn), Id: aws.StringValue(latticeTgSummary.Id), } } s.log.Debugf(ctx, "Deleting target group %s", modelTg.Status.Id) lattice := s.cloud.Lattice() // de-register all targets first listTargetsInput := vpclattice.ListTargetsInput{ TargetGroupIdentifier: &modelTg.Status.Id, } listResp, err := lattice.ListTargetsAsList(ctx, &listTargetsInput) if err != nil { if services.IsLatticeAPINotFoundErr(err) { s.log.Debugf(ctx, "Target group %s was already deleted", modelTg.Status.Id) return nil } return fmt.Errorf("failed ListTargets %s due to %s", modelTg.Status.Id, err) } var targetsToDeregister []*vpclattice.Target drainCount := 0 for _, t := range listResp { targetsToDeregister = append(targetsToDeregister, &vpclattice.Target{ Id: t.Id, Port: t.Port, }) if aws.StringValue(t.Status) == vpclattice.TargetStatusDraining { drainCount++ } } if drainCount > 0 { // no point in trying to deregister may as well wait return fmt.Errorf("cannot deregister targets for %s as %d targets are DRAINING", modelTg.Status.Id, drainCount) } if len(targetsToDeregister) > 0 { var deregisterTargetsError error chunks := utils.Chunks(targetsToDeregister, maxTargetsPerLatticeTargetsApiCall) for i, targets := range chunks { deregisterInput := vpclattice.DeregisterTargetsInput{ TargetGroupIdentifier: &modelTg.Status.Id, Targets: targets, } deregisterResponse, err := lattice.DeregisterTargetsWithContext(ctx, &deregisterInput) if err != nil { deregisterTargetsError = errors.Join(deregisterTargetsError, fmt.Errorf("failed to deregister targets from VPC Lattice Target Group %s due to %s", modelTg.Status.Id, err)) } if len(deregisterResponse.Unsuccessful) > 0 { deregisterTargetsError = errors.Join(deregisterTargetsError, fmt.Errorf("failed to deregister targets from VPC Lattice Target Group %s for chunk %d/%d, unsuccessful targets %v", modelTg.Status.Id, i+1, len(chunks), deregisterResponse.Unsuccessful)) } s.log.Debugf(ctx, "Successfully deregistered targets from VPC Lattice Target Group %s for chunk %d/%d", modelTg.Status.Id, i+1, len(chunks)) } if deregisterTargetsError != nil { return deregisterTargetsError } } deleteTGInput := vpclattice.DeleteTargetGroupInput{ TargetGroupIdentifier: &modelTg.Status.Id, } _, err = lattice.DeleteTargetGroupWithContext(ctx, &deleteTGInput) if err != nil { if services.IsLatticeAPINotFoundErr(err) { s.log.Infof(ctx, "Target group %s was already deleted", modelTg.Status.Id) return nil } else { return fmt.Errorf("failed DeleteTargetGroup %s due to %s", modelTg.Status.Id, err) } } s.log.Infof(ctx, "Success DeleteTargetGroup %s", modelTg.Status.Id) return nil } type tgListOutput struct { tgSummary *vpclattice.TargetGroupSummary tags services.Tags } // Retrieve all TGs in the account, including tags. If individual tags fetch fails, tags will be nil for that tg func (s *defaultTargetGroupManager) List(ctx context.Context) ([]tgListOutput, error) { lattice := s.cloud.Lattice() var tgList []tgListOutput targetGroupListInput := vpclattice.ListTargetGroupsInput{} resp, err := lattice.ListTargetGroupsAsList(ctx, &targetGroupListInput) if err != nil { return nil, err } if len(resp) == 0 { return nil, nil } tgArns := utils.SliceMap(resp, func(tg *vpclattice.TargetGroupSummary) string { return aws.StringValue(tg.Arn) }) tgArnToTagsMap, err := s.cloud.Tagging().GetTagsForArns(ctx, tgArns) if err != nil { return nil, err } for _, tg := range resp { tgList = append(tgList, tgListOutput{ tgSummary: tg, tags: tgArnToTagsMap[*tg.Arn], }) } return tgList, err } func (s *defaultTargetGroupManager) findTargetGroup( ctx context.Context, modelTargetGroup *model.TargetGroup, ) (*vpclattice.GetTargetGroupOutput, error) { arns, err := s.cloud.Tagging().FindResourcesByTags(ctx, services.ResourceTypeTargetGroup, model.TagsFromTGTagFields(modelTargetGroup.Spec.TargetGroupTagFields)) if err != nil { return nil, err } if len(arns) == 0 { return nil, nil } for _, arn := range arns { latticeTg, err := s.cloud.Lattice().GetTargetGroupWithContext(ctx, &vpclattice.GetTargetGroupInput{ TargetGroupIdentifier: &arn, }) if err != nil { if services.IsNotFoundError(err) { continue } return nil, err } // we ignore create failed status, so may as well check for it first status := aws.StringValue(latticeTg.Status) if status == vpclattice.TargetGroupStatusCreateFailed { continue } // Check the immutable fields to ensure TG is valid match, err := s.IsTargetGroupMatch(ctx, modelTargetGroup, &vpclattice.TargetGroupSummary{ Arn: latticeTg.Arn, Port: latticeTg.Config.Port, Protocol: latticeTg.Config.Protocol, IpAddressType: latticeTg.Config.IpAddressType, Type: latticeTg.Type, VpcIdentifier: latticeTg.Config.VpcIdentifier, }, nil) // we already know that tags match if err != nil { return nil, err } if match { switch status { case vpclattice.TargetGroupStatusCreateInProgress, vpclattice.TargetGroupStatusDeleteInProgress: return nil, errors.New(LATTICE_RETRY) case vpclattice.TargetGroupStatusDeleteFailed, vpclattice.TargetGroupStatusActive: return latticeTg, nil } } } return nil, nil } // Skips tag verification if not provided func (s *defaultTargetGroupManager) IsTargetGroupMatch(ctx context.Context, modelTg *model.TargetGroup, latticeTg *vpclattice.TargetGroupSummary, latticeTagsAsModelTags *model.TargetGroupTagFields) (bool, error) { if aws.Int64Value(latticeTg.Port) != int64(modelTg.Spec.Port) || aws.StringValue(latticeTg.Protocol) != modelTg.Spec.Protocol || aws.StringValue(latticeTg.IpAddressType) != modelTg.Spec.IpAddressType || aws.StringValue(latticeTg.Type) != string(modelTg.Spec.Type) || aws.StringValue(latticeTg.VpcIdentifier) != modelTg.Spec.VpcId { return false, nil } if latticeTagsAsModelTags != nil { tagsMatch := model.TagFieldsMatch(modelTg.Spec, *latticeTagsAsModelTags) if !tagsMatch { return false, nil } } return true, nil } // Get default health check configuration according to // https://docs.aws.amazon.com/vpc-lattice/latest/ug/target-group-health-checks.html#health-check-settings func (s *defaultTargetGroupManager) getDefaultHealthCheckConfig(targetGroupProtocol string, targetGroupProtocolVersion string) *vpclattice.HealthCheckConfig { if targetGroupProtocol == vpclattice.TargetGroupProtocolTcp { return &vpclattice.HealthCheckConfig{ Enabled: aws.Bool(false), } } var ( defaultHealthCheckIntervalSeconds int64 = 30 defaultHealthCheckTimeoutSeconds int64 = 5 defaultHealthyThresholdCount int64 = 5 defaultUnhealthyThresholdCount int64 = 2 defaultMatcher = vpclattice.Matcher{ HttpCode: aws.String("200"), } defaultPath = "/" defaultProtocol = vpclattice.TargetGroupProtocolHttp ) if targetGroupProtocolVersion == "" { targetGroupProtocolVersion = vpclattice.TargetGroupProtocolVersionHttp1 } enabled := targetGroupProtocolVersion == vpclattice.TargetGroupProtocolVersionHttp1 healthCheckProtocolVersion := targetGroupProtocolVersion if targetGroupProtocolVersion == vpclattice.TargetGroupProtocolVersionGrpc { healthCheckProtocolVersion = vpclattice.HealthCheckProtocolVersionHttp1 } return &vpclattice.HealthCheckConfig{ Enabled: &enabled, Protocol: &defaultProtocol, ProtocolVersion: &healthCheckProtocolVersion, Path: &defaultPath, Matcher: &defaultMatcher, Port: nil, // Use target port HealthyThresholdCount: &defaultHealthyThresholdCount, UnhealthyThresholdCount: &defaultUnhealthyThresholdCount, HealthCheckTimeoutSeconds: &defaultHealthCheckTimeoutSeconds, HealthCheckIntervalSeconds: &defaultHealthCheckIntervalSeconds, } } func (s *defaultTargetGroupManager) fillDefaultHealthCheckConfig(hc *vpclattice.HealthCheckConfig, targetGroupProtocol string, targetGroupProtocolVersion string) { defaultCfg := s.getDefaultHealthCheckConfig(targetGroupProtocol, targetGroupProtocolVersion) if hc.Enabled == nil { hc.Enabled = defaultCfg.Enabled } if hc.Protocol == nil { hc.Protocol = defaultCfg.Protocol } if hc.ProtocolVersion == nil { hc.ProtocolVersion = defaultCfg.ProtocolVersion } if hc.Path == nil { hc.Path = defaultCfg.Path } if hc.Matcher == nil { hc.Matcher = defaultCfg.Matcher } if hc.HealthCheckTimeoutSeconds == nil { hc.HealthCheckTimeoutSeconds = defaultCfg.HealthCheckTimeoutSeconds } if hc.HealthCheckIntervalSeconds == nil { hc.HealthCheckIntervalSeconds = defaultCfg.HealthCheckIntervalSeconds } if hc.HealthyThresholdCount == nil { hc.HealthyThresholdCount = defaultCfg.HealthyThresholdCount } if hc.UnhealthyThresholdCount == nil { hc.UnhealthyThresholdCount = defaultCfg.UnhealthyThresholdCount } } func (s *defaultTargetGroupManager) findSvcExportTG(ctx context.Context, svcImportTg model.SvcImportTargetGroup) (string, error) { tgs, err := s.List(ctx) if err != nil { return "", err } for _, tg := range tgs { tgTags := model.TGTagFieldsFromTags(tg.tags) svcMatch := tgTags.IsSourceTypeServiceExport() && (tgTags.K8SServiceName == svcImportTg.K8SServiceName) && (tgTags.K8SServiceNamespace == svcImportTg.K8SServiceNamespace) clusterMatch := (svcImportTg.K8SClusterName == "") || (tgTags.K8SClusterName == svcImportTg.K8SClusterName) vpcMatch := (svcImportTg.VpcId == "") || (svcImportTg.VpcId == aws.StringValue(tg.tgSummary.VpcIdentifier)) if svcMatch && clusterMatch && vpcMatch { return *tg.tgSummary.Id, nil } } return "", errors.New("target group for service import could not be found") } // ResolveRuleTgIds populates all target group ids in the rule's actions func (s *defaultTargetGroupManager) ResolveRuleTgIds(ctx context.Context, ruleAction *model.RuleAction, stack core.Stack) error { if len(ruleAction.TargetGroups) == 0 { s.log.Debugf(ctx, "no target groups to resolve for rule") return nil } for i, ruleActionTg := range ruleAction.TargetGroups { if ruleActionTg.StackTargetGroupId == "" && ruleActionTg.SvcImportTG == nil && ruleActionTg.LatticeTgId == "" { return errors.New("rule TG is missing a required target group identifier") } if ruleActionTg.LatticeTgId != "" { s.log.Debugf(ctx, "Rule TG %d already resolved %s", i, ruleActionTg.LatticeTgId) continue } if ruleActionTg.StackTargetGroupId != "" { if ruleActionTg.StackTargetGroupId == model.InvalidBackendRefTgId { s.log.Debugf(ctx, "Rule TG has an invalid backendref, setting TG id to invalid") ruleActionTg.LatticeTgId = model.InvalidBackendRefTgId continue } s.log.Debugf(ctx, "Fetching TG %d from the stack (ID %s)", i, ruleActionTg.StackTargetGroupId) stackTg := &model.TargetGroup{} err := stack.GetResource(ruleActionTg.StackTargetGroupId, stackTg) if err != nil { return err } if stackTg.Status == nil { return errors.New("stack target group is missing Status field") } ruleActionTg.LatticeTgId = stackTg.Status.Id } if ruleActionTg.SvcImportTG != nil { s.log.Debugf(ctx, "Getting target group for service import %s %s (%s, %s)", ruleActionTg.SvcImportTG.K8SServiceName, ruleActionTg.SvcImportTG.K8SServiceNamespace, ruleActionTg.SvcImportTG.K8SClusterName, ruleActionTg.SvcImportTG.VpcId) tgId, err := s.findSvcExportTG(ctx, *ruleActionTg.SvcImportTG) if err != nil { return err } ruleActionTg.LatticeTgId = tgId } } return nil }