pkg/providers/subnet/subnet.go (230 lines of code) (raw):
/*
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package subnet
import (
"context"
"fmt"
"net/http"
"sync"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ec2"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/awslabs/operatorpkg/serrors"
"github.com/mitchellh/hashstructure/v2"
"github.com/patrickmn/go-cache"
"github.com/samber/lo"
corev1 "k8s.io/api/core/v1"
"sigs.k8s.io/controller-runtime/pkg/log"
v1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1"
karpv1 "sigs.k8s.io/karpenter/pkg/apis/v1"
"sigs.k8s.io/karpenter/pkg/cloudprovider"
"sigs.k8s.io/karpenter/pkg/scheduling"
"sigs.k8s.io/karpenter/pkg/utils/pretty"
sdk "github.com/aws/karpenter-provider-aws/pkg/aws"
)
// Provider is the subnet provider contract used to discover the subnets
// selected by an EC2NodeClass and to account for IP addresses consumed by
// instance launches. DefaultProvider in this file is an implementation.
type Provider interface {
// LivenessProbe returns a non-nil error when the provider is considered
// unhealthy. DefaultProvider's implementation only acquires and releases
// its mutex and always returns nil.
LivenessProbe(*http.Request) error
// List returns the EC2 subnets that match the subnet selector terms of the
// given EC2NodeClass.
List(context.Context, *v1.EC2NodeClass) ([]ec2types.Subnet, error)
// ZonalSubnetsForLaunch returns a map of zone name to the subnet chosen for
// that zone, given the candidate instance types and the capacity type
// (the trailing string argument) of the launch.
ZonalSubnetsForLaunch(context.Context, *v1.EC2NodeClass, []*cloudprovider.InstanceType, string) (map[string]*Subnet, error)
// UpdateInflightIPs reconciles the in-memory IP tracking after a
// CreateFleet call, using the request, its response, the candidate
// instance types, the subnets offered to the launch and the capacity type.
UpdateInflightIPs(*ec2.CreateFleetInput, *ec2.CreateFleetOutput, []*cloudprovider.InstanceType, []*Subnet, string)
}
// DefaultProvider is the EC2-backed implementation of Provider. All of its
// methods in this file take the embedded mutex before touching the caches or
// the in-flight IP map.
type DefaultProvider struct {
// Mutex serializes List, ZonalSubnetsForLaunch, UpdateInflightIPs and
// LivenessProbe.
sync.Mutex
// ec2api is the EC2 client used to page through DescribeSubnets.
ec2api sdk.EC2API
// cache stores the []ec2types.Subnet result of List, keyed by the hash of
// the filter sets derived from the subnet selector terms.
cache *cache.Cache
// availableIPAddressCache stores each subnet's AvailableIpAddressCount
// (int32) as last reported by EC2, keyed by subnet ID.
availableIPAddressCache *cache.Cache
// associatePublicIPAddressCache stores each subnet's MapPublicIpOnLaunch
// value as last reported by EC2, keyed by subnet ID.
associatePublicIPAddressCache *cache.Cache
// cm is used by List to log the discovered subnets only when the set of
// subnet IDs for a node class changes.
cm *pretty.ChangeMonitor
// inflightIPs maps subnet ID to the predicted number of IP addresses still
// available after accounting for launches made since the last refresh from
// EC2. List removes an entry whenever it re-reads that subnet from EC2.
inflightIPs map[string]int32
}
// Subnet is the provider's lightweight view of a subnet offered to a launch.
type Subnet struct {
// ID is the subnet ID; it is the key used in the IP caches and inflightIPs.
ID string
// Zone is the availability zone name of the subnet.
Zone string
// ZoneID is the availability zone ID of the subnet.
ZoneID string
// AvailableIPAddressCount is the available IP address count taken from
// availableIPAddressCache at the time the Subnet was built by
// ZonalSubnetsForLaunch (zero when the subnet was not in that cache).
AvailableIPAddressCount int32
}
// NewDefaultProvider returns a DefaultProvider wired to the given EC2 client
// and caches, with a fresh change monitor and an empty in-flight IP tracker.
//
// cache holds the subnet lists returned by List; availableIPAddressCache and
// associatePublicIPAddressCache hold per-subnet values keyed by subnet ID.
//
// NOTE(review): an earlier comment here stated that cached subnets are sorted
// by AvailableIpAddressCount in descending order. List caches lo.Values of a
// map and does not sort, so callers should not rely on any ordering — confirm
// whether the sort is expected to happen elsewhere.
func NewDefaultProvider(ec2api sdk.EC2API, cache *cache.Cache, availableIPAddressCache *cache.Cache, associatePublicIPAddressCache *cache.Cache) *DefaultProvider {
	provider := &DefaultProvider{
		cm: pretty.NewChangeMonitor(),
		// inflightIPs tracks IPs consumed by instances this process is known to have launched
		inflightIPs: make(map[string]int32),
	}
	provider.ec2api = ec2api
	// TODO: Remove cache when we utilize the resolved subnets from the EC2NodeClass.status
	provider.cache = cache
	provider.availableIPAddressCache = availableIPAddressCache
	provider.associatePublicIPAddressCache = associatePublicIPAddressCache
	return provider
}
// List returns the EC2 subnets selected by nodeClass.Spec.SubnetSelectorTerms.
//
// Results are cached under the hash of the derived filter sets. On a cache hit
// a copy of the cached slice is returned so that callers may reorder it freely
// without disturbing the cached value. On a miss, every filter set is queried
// through DescribeSubnets, the results are de-duplicated by subnet ID, the
// per-subnet IP caches are refreshed, and any in-flight IP tracking for a
// refreshed subnet is discarded. An empty, non-nil slice is returned when the
// selector terms yield no filter sets.
func (p *DefaultProvider) List(ctx context.Context, nodeClass *v1.EC2NodeClass) ([]ec2types.Subnet, error) {
	p.Lock()
	defer p.Unlock()

	filterSets := getFilterSets(nodeClass.Spec.SubnetSelectorTerms)
	if len(filterSets) == 0 {
		return []ec2types.Subnet{}, nil
	}
	hash, err := hashstructure.Hash(filterSets, hashstructure.FormatV2, &hashstructure.HashOptions{SlicesAsSets: true})
	if err != nil {
		return nil, err
	}
	cacheKey := fmt.Sprint(hash)
	if cached, found := p.cache.Get(cacheKey); found {
		// Hand back a shallow copy of the slice (the elements themselves are not deep-copied)
		// so that a caller re-sorting the result leaves the cached slice untouched.
		cachedSubnets := cached.([]ec2types.Subnet)
		return append(make([]ec2types.Subnet, 0, len(cachedSubnets)), cachedSubnets...), nil
	}

	// Keyed by subnet ID so a subnet matched by several filter sets is only returned once.
	subnetsByID := map[string]ec2types.Subnet{}
	for _, filters := range filterSets {
		paginator := ec2.NewDescribeSubnetsPaginator(p.ec2api, &ec2.DescribeSubnetsInput{
			Filters:    filters,
			MaxResults: lo.ToPtr(int32(500)),
		})
		for paginator.HasMorePages() {
			page, err := paginator.NextPage(ctx)
			if err != nil {
				return nil, serrors.Wrap(fmt.Errorf("describing subnets with filters, %w", err), "filters", pretty.Concise(filters))
			}
			for i := range page.Subnets {
				subnet := &page.Subnets[i]
				id := lo.FromPtr(subnet.SubnetId)
				subnetsByID[id] = *subnet
				p.availableIPAddressCache.SetDefault(id, lo.FromPtr(subnet.AvailableIpAddressCount))
				p.associatePublicIPAddressCache.SetDefault(id, lo.FromPtr(subnet.MapPublicIpOnLaunch))
				// The count was just refreshed from EC2, so previously tracked in-flight IPs are stale.
				// An entry for a subnet that EC2 never returns again is never removed; that leak is
				// accepted for now because the amount of memory involved is insignificant.
				delete(p.inflightIPs, id)
			}
		}
	}
	p.cache.SetDefault(cacheKey, lo.Values(subnetsByID))

	// Only log when the set of subnet IDs for this node class differs from the last observation.
	if p.cm.HasChanged(fmt.Sprintf("subnets/%s", nodeClass.Name), lo.Keys(subnetsByID)) {
		discovered := lo.Map(lo.Values(subnetsByID), func(s ec2types.Subnet, _ int) v1.Subnet {
			return v1.Subnet{
				ID:     lo.FromPtr(s.SubnetId),
				Zone:   lo.FromPtr(s.AvailabilityZone),
				ZoneID: lo.FromPtr(s.AvailabilityZoneId),
			}
		})
		log.FromContext(ctx).WithValues("subnets", discovered).V(1).Info("discovered subnets")
	}
	// A separate slice from the one placed in the cache, for the same reason as the cache-hit copy.
	return lo.Values(subnetsByID), nil
}
// ZonalSubnetsForLaunch returns, for each zone present in nodeClass.Status.Subnets,
// the subnet believed to have the most available IP addresses, and reserves the
// predicted IP usage of the launch against every returned subnet.
//
// A subnet's available count is the in-flight tracked value when one exists,
// otherwise the value held in availableIPAddressCache (zero if absent). When two
// subnets in a zone tie, the one listed first in the status wins. The reservation
// per subnet is minPods for the given instance types, capacity type and zone.
// An error is returned when the node class status lists no subnets.
func (p *DefaultProvider) ZonalSubnetsForLaunch(ctx context.Context, nodeClass *v1.EC2NodeClass, instanceTypes []*cloudprovider.InstanceType, capacityType string) (map[string]*Subnet, error) {
	if len(nodeClass.Status.Subnets) == 0 {
		return nil, fmt.Errorf("no subnets matched selector %v", nodeClass.Spec.SubnetSelectorTerms)
	}
	p.Lock()
	defer p.Unlock()

	// Snapshot the last EC2-reported counts for the subnets in the status.
	cachedIPs := map[string]int32{}
	for _, statusSubnet := range nodeClass.Status.Subnets {
		if count, found := p.availableIPAddressCache.Get(statusSubnet.ID); found {
			cachedIPs[statusSubnet.ID] = count.(int32)
		}
	}
	// remainingIPs prefers the in-flight tracked count over the supplied fallback.
	remainingIPs := func(subnetID string, fallback int32) int32 {
		if tracked, ok := p.inflightIPs[subnetID]; ok {
			return tracked
		}
		return fallback
	}

	zonalSubnets := map[string]*Subnet{}
	for _, candidate := range nodeClass.Status.Subnets {
		if incumbent, exists := zonalSubnets[candidate.Zone]; exists {
			// Keep the incumbent unless the candidate has strictly more IPs remaining.
			if remainingIPs(incumbent.ID, incumbent.AvailableIPAddressCount) >= remainingIPs(candidate.ID, cachedIPs[candidate.ID]) {
				continue
			}
		}
		zonalSubnets[candidate.Zone] = &Subnet{
			ID:                      candidate.ID,
			Zone:                    candidate.Zone,
			ZoneID:                  candidate.ZoneID,
			AvailableIPAddressCount: cachedIPs[candidate.ID],
		}
	}

	// Deduct the predicted usage from every chosen subnet; UpdateInflightIPs adds it back
	// for the subnets CreateFleet ends up not launching into.
	for _, chosen := range zonalSubnets {
		reserved := p.minPods(instanceTypes, scheduling.NewRequirements(
			scheduling.NewRequirement(karpv1.CapacityTypeLabelKey, corev1.NodeSelectorOpIn, capacityType),
			scheduling.NewRequirement(corev1.LabelTopologyZone, corev1.NodeSelectorOpIn, chosen.Zone),
		))
		p.inflightIPs[chosen.ID] = remainingIPs(chosen.ID, chosen.AvailableIPAddressCount) - reserved
	}
	return zonalSubnets, nil
}
// UpdateInflightIPs refreshes the in-memory IP tracking once a CreateFleet
// response is available, returning the IPs that ZonalSubnetsForLaunch reserved
// on subnets that were offered to Fleet but that Fleet did not launch into.
//
// A subnet is credited back only when all of the following hold: it appears in
// the request overrides but on no launched instance in the response; it has an
// entry in availableIPAddressCache; it is present in subnets; its cached count
// still equals the count recorded on the launch-time Subnet (i.e. the cache has
// not been refreshed since); and it has an in-flight entry. The amount credited
// is minPods for the given instance types, capacity type and the subnet's zone.
// createFleetOutput may be nil, in which case every requested subnet is treated
// as unfulfilled. createFleetInput must be non-nil.
func (p *DefaultProvider) UpdateInflightIPs(createFleetInput *ec2.CreateFleetInput, createFleetOutput *ec2.CreateFleetOutput, instanceTypes []*cloudprovider.InstanceType,
	subnets []*Subnet, capacityType string) {
	p.Lock()
	defer p.Unlock()

	// Collect every subnet ID offered to Fleet in the request overrides.
	var requested []string
	for _, config := range createFleetInput.LaunchTemplateConfigs {
		for i := range config.Overrides {
			requested = append(requested, lo.FromPtr(config.Overrides[i].SubnetId))
		}
	}
	fleetInputSubnets := lo.Compact(lo.Uniq(requested))

	// Collect every subnet ID that Fleet actually launched an instance into.
	var fleetOutputSubnets []string
	if createFleetOutput != nil {
		var fulfilled []string
		for i := range createFleetOutput.Instances {
			launched := createFleetOutput.Instances[i].LaunchTemplateAndOverrides
			if launched == nil || launched.Overrides == nil {
				continue
			}
			fulfilled = append(fulfilled, lo.FromPtr(launched.Overrides.SubnetId))
		}
		fleetOutputSubnets = lo.Compact(lo.Uniq(fulfilled))
	}

	// Subnets that were requested but not chosen by Fleet need their reserved IPs added back.
	subnetIDsToAddBackIPs, _ := lo.Difference(fleetInputSubnets, fleetOutputSubnets)

	// Snapshot the cached available IP address count of every subnet currently in the cache.
	cachedAvailableIPAddressMap := lo.MapEntries(p.availableIPAddressCache.Items(), func(k string, v cache.Item) (string, int32) {
		return k, v.Object.(int32)
	})

	// Credit back only the subnets whose cached count has not been synchronized with EC2
	// since the initial deduction of IP addresses made before the instance launch.
	for _, subnetID := range subnetIDsToAddBackIPs {
		cachedIPAddressCount, cached := cachedAvailableIPAddressMap[subnetID]
		if !cached {
			continue
		}
		originalSubnet, found := lo.Find(subnets, func(s *Subnet) bool {
			return s.ID == subnetID
		})
		if !found {
			continue
		}
		// A changed cached count means the subnet was refreshed after the launch decision,
		// so the tracked value no longer includes the reservation and must be left alone.
		if originalSubnet.AvailableIPAddressCount != cachedIPAddressCount {
			continue
		}
		tracked, isTracked := p.inflightIPs[originalSubnet.ID]
		if !isTracked {
			continue
		}
		// The deduction on this subnet was opportunistic; Fleet did not launch into it.
		released := p.minPods(instanceTypes, scheduling.NewRequirements(
			scheduling.NewRequirement(karpv1.CapacityTypeLabelKey, corev1.NodeSelectorOpIn, capacityType),
			scheduling.NewRequirement(corev1.LabelTopologyZone, corev1.NodeSelectorOpIn, originalSubnet.Zone),
		))
		p.inflightIPs[originalSubnet.ID] = tracked + released
	}
}
// LivenessProbe acquires and immediately releases the provider mutex and then
// returns nil. It has no other effect; the call blocks for as long as another
// method of the provider is holding the lock. The request argument is ignored.
func (p *DefaultProvider) LivenessProbe(_ *http.Request) error {
	p.Lock()
	defer p.Unlock()
	return nil
}
// minPods returns the smallest pod capacity among the instance types that have
// an available offering compatible with reqs, or 0 when none do. It is the
// figure used both to pick a subnet and to deduct the IPs a launch is expected
// to consume.
func (p *DefaultProvider) minPods(instanceTypes []*cloudprovider.InstanceType, reqs scheduling.Requirements) int32 {
	// Single pass: skip instance types without a compatible available offering and
	// remember the first one with the lowest pod capacity.
	var smallest *cloudprovider.InstanceType
	for _, candidate := range instanceTypes {
		if !candidate.Offerings.Available().HasCompatible(reqs) {
			continue
		}
		if smallest == nil || candidate.Capacity.Pods().Cmp(*smallest.Capacity.Pods()) < 0 {
			smallest = candidate
		}
	}
	if smallest == nil {
		return 0
	}
	// The ok flag from AsInt64 is intentionally discarded, as is the int64 -> int32 narrowing check.
	pods, _ := smallest.Capacity.Pods().AsInt64()
	//nolint:gosec
	return int32(pods)
}
// getFilterSets converts subnet selector terms into DescribeSubnets filter sets.
//
// Every term without an ID yields one filter set built from its tags: a tag
// whose value is "*" becomes a "tag-key" filter on the key, any other tag
// becomes a "tag:<key>" filter on the value. A term with an ID contributes only
// its ID (its tags are not consulted); all such IDs are combined into a single
// "subnet-id" filter set appended after the tag-based sets. The result is nil
// when terms is empty.
func getFilterSets(terms []v1.SubnetSelectorTerm) (res [][]ec2types.Filter) {
	var subnetIDs []string
	for _, term := range terms {
		if term.ID != "" {
			subnetIDs = append(subnetIDs, term.ID)
			continue
		}
		var tagFilters []ec2types.Filter
		for key, value := range term.Tags {
			tagFilter := ec2types.Filter{
				Name:   aws.String(fmt.Sprintf("tag:%s", key)),
				Values: []string{value},
			}
			// A wildcard value only requires the tag key to be present.
			if value == "*" {
				tagFilter = ec2types.Filter{
					Name:   aws.String("tag-key"),
					Values: []string{key},
				}
			}
			tagFilters = append(tagFilters, tagFilter)
		}
		res = append(res, tagFilters)
	}
	if len(subnetIDs) > 0 {
		res = append(res, []ec2types.Filter{{Name: aws.String("subnet-id"), Values: subnetIDs}})
	}
	return res
}