pkg/providers/instance/instance.go
/*
Copyright (c) Microsoft Corporation.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package instance
import (
"context"
"fmt"
"regexp"
"strings"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4"
"github.com/azure/gpu-provisioner/pkg/utils"
"github.com/samber/lo"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/util/retry"
"k8s.io/klog/v2"
"knative.dev/pkg/logging"
"sigs.k8s.io/controller-runtime/pkg/client"
karpenterv1 "sigs.k8s.io/karpenter/pkg/apis/v1"
"sigs.k8s.io/karpenter/pkg/cloudprovider"
"sigs.k8s.io/karpenter/pkg/scheduling"
)
const (
LabelMachineType = "kaito.sh/machine-type"
NodeClaimCreationLabel = "kaito.sh/creation-timestamp"
// use a self-defined layout so the rendered timestamp satisfies node label syntax
CreationTimestampLayout = "2006-01-02T15-04-05Z"
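// e.g. a creation time of 2024-03-01 08:30:00 UTC is rendered as "2024-03-01T08-30-00Z"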
)
var (
KaitoNodeLabels = []string{"kaito.sh/workspace", "kaito.sh/ragengine"}
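// AgentPoolNameRegex matches valid AKS agent pool names: a lowercase letter followed by up to 11
// lowercase letters or digits, e.g. "gpunode0"; names such as "gpu-node" or "GPUNode" are rejected.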
AgentPoolNameRegex = regexp.MustCompile(`^[a-z][a-z0-9]{0,11}$`)
)
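// Provider creates, gets, lists, and deletes AKS agent pools on behalf of karpenter NodeClaims
// within a single resource group and managed cluster.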
type Provider struct {
azClient *AZClient
kubeClient client.Client
resourceGroup string
clusterName string
}
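// NewProvider constructs a Provider bound to the given Azure clients, kube client, resource group, and cluster name.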
func NewProvider(
azClient *AZClient,
kubeClient client.Client,
resourceGroup string,
clusterName string,
) *Provider {
return &Provider{
azClient: azClient,
kubeClient: kubeClient,
resourceGroup: resourceGroup,
clusterName: clusterName,
}
}
// Create an instance given the constraints.
// instanceTypes should be sorted by priority for spot capacity type.
func (p *Provider) Create(ctx context.Context, nodeClaim *karpenterv1.NodeClaim) (*Instance, error) {
klog.InfoS("Instance.Create", "nodeClaim", klog.KObj(nodeClaim))
// We make a strong assumption here: the nodeClaim name must already be a valid agent pool name, in particular it must not contain "-".
apName := nodeClaim.Name
if !AgentPoolNameRegex.MatchString(apName) {
//https://learn.microsoft.com/en-us/troubleshoot/azure/azure-kubernetes/aks-common-issues-faq#what-naming-restrictions-are-enforced-for-aks-resources-and-parameters-
return nil, fmt.Errorf("agentpool name(%s) is invalid, must match regex pattern: ^[a-z][a-z0-9]{0,11}$", apName)
}
var ap *armcontainerservice.AgentPool
err := retry.OnError(retry.DefaultBackoff, func(err error) bool {
return false
}, func() error {
instanceTypes := scheduling.NewNodeSelectorRequirementsWithMinValues(nodeClaim.Spec.Requirements...).Get("node.kubernetes.io/instance-type").Values()
if len(instanceTypes) == 0 {
return fmt.Errorf("nodeClaim spec has no requirement for instance type")
}
vmSize := instanceTypes[0]
apObj, apErr := newAgentPoolObject(vmSize, nodeClaim)
if apErr != nil {
return apErr
}
logging.FromContext(ctx).Debugf("creating Agent pool %s (%s)", apName, vmSize)
var err error
ap, err = createAgentPool(ctx, p.azClient.agentPoolsClient, p.resourceGroup, apName, p.clusterName, apObj)
if err != nil {
switch {
case strings.Contains(err.Error(), "Operation is not allowed because there's an in progress create node pool operation"):
// If gpu-provisioner restarted after a crash for an unknown reason, we may hit this error because the
// agent pool creation is already in progress; in that case we just wait for the node to become ready based on apObj.
ap = &apObj
return nil
default:
logging.FromContext(ctx).Errorf("failed to create agent pool for nodeclaim(%s), %v", nodeClaim.Name, err)
return fmt.Errorf("agentPool.BeginCreateOrUpdate for %q failed: %w", apName, err)
}
}
logging.FromContext(ctx).Debugf("created agent pool %s", *ap.ID)
return nil
})
if err != nil {
return nil, err
}
instance, err := p.fromRegisteredAgentPoolToInstance(ctx, ap)
if instance == nil && err == nil {
// the node object has not been found yet, so wait until the node is created
b := wait.Backoff{
Steps: 15,
Duration: 1 * time.Second,
Factor: 1.0,
Jitter: 0.1,
}
err = retry.OnError(b, func(err error) bool {
return true
}, func() error {
var e error
instance, e = p.fromRegisteredAgentPoolToInstance(ctx, ap)
if e != nil {
return e
}
if instance == nil {
return fmt.Errorf("fail to find the node object")
}
return nil
})
if err != nil {
return nil, err
}
}
return instance, err
}
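// Get looks up the agent pool behind the given provider id and converts it to an Instance.
// A missing agent pool is reported as a NodeClaimNotFoundError.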
func (p *Provider) Get(ctx context.Context, id string) (*Instance, error) {
apName, err := utils.ParseAgentPoolNameFromID(id)
if err != nil {
return nil, fmt.Errorf("getting agentpool name, %w", err)
}
apObj, err := getAgentPool(ctx, p.azClient.agentPoolsClient, p.resourceGroup, p.clusterName, apName)
if err != nil {
if strings.Contains(err.Error(), "Agent Pool not found") {
return nil, cloudprovider.NewNodeClaimNotFoundError(err)
}
logging.FromContext(ctx).Errorf("Get agentpool %q failed: %v", apName, err)
return nil, fmt.Errorf("agentPool.Get for %s failed: %w", apName, err)
}
return p.convertAgentPoolToInstance(ctx, apObj, id)
}
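// List returns Instances for all agent pools in the cluster that are owned by kaito and were created from a nodeClaim.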
func (p *Provider) List(ctx context.Context) ([]*Instance, error) {
apList, err := listAgentPools(ctx, p.azClient.agentPoolsClient, p.resourceGroup, p.clusterName)
if err != nil {
logging.FromContext(ctx).Errorf("Listing agentpools failed: %v", err)
return nil, fmt.Errorf("agentPool.NewListPager failed: %w", err)
}
instances, err := p.fromAPListToInstances(ctx, apList)
return instances, cloudprovider.IgnoreNodeClaimNotFoundError(err)
}
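// Delete removes the named agent pool from the managed cluster.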
func (p *Provider) Delete(ctx context.Context, apName string) error {
klog.InfoS("Instance.Delete", "agentpool name", apName)
err := deleteAgentPool(ctx, p.azClient.agentPoolsClient, p.resourceGroup, p.clusterName, apName)
if err != nil {
logging.FromContext(ctx).Errorf("Deleting agentpool %q failed: %v", apName, err)
return fmt.Errorf("agentPool.Delete for %q failed: %w", apName, err)
}
return nil
}
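// convertAgentPoolToInstance maps an agent pool and the caller-supplied provider id to an Instance.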
func (p *Provider) convertAgentPoolToInstance(ctx context.Context, apObj *armcontainerservice.AgentPool, id string) (*Instance, error) {
if apObj == nil || len(id) == 0 {
return nil, fmt.Errorf("agent pool or provider id is nil")
}
instanceLabels := lo.MapValues(apObj.Properties.NodeLabels, func(v *string, _ string) string {
return lo.FromPtr(v)
})
return &Instance{
Name: apObj.Name,
ID: to.Ptr(id),
Type: apObj.Properties.VMSize,
SubnetID: apObj.Properties.VnetSubnetID,
Tags: apObj.Properties.Tags,
State: apObj.Properties.ProvisioningState,
Labels: instanceLabels,
ImageID: apObj.Properties.NodeImageVersion,
}, nil
}
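// fromRegisteredAgentPoolToInstance builds an Instance from an agent pool whose backing node has registered
// with the cluster. It returns (nil, nil) while exactly one node with a non-empty providerID is not yet
// available, so callers can retry.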
func (p *Provider) fromRegisteredAgentPoolToInstance(ctx context.Context, apObj *armcontainerservice.AgentPool) (*Instance, error) {
if apObj == nil {
return nil, fmt.Errorf("agent pool is nil")
}
nodes, err := p.getNodesByName(ctx, lo.FromPtr(apObj.Name))
if err != nil {
return nil, err
}
if len(nodes) != 1 {
// NotFound is not considered an error. The agent pool may also briefly create more than one
// instance; in that case we wait for the agent pool to remove the spare instance.
return nil, nil
}
// We only want to resolve the providerID and construct the Instance from the AgentPool;
// there is no need to verify the node ready condition, so the following check is commented out.
// if node == nil || nodeutil.GetCondition(node, v1.NodeReady).Status != v1.ConditionTrue {
// // node is not found or not ready
// return nil, nil
// }
// When creating an AgentPool we need to wait for the node and its providerID to be ready,
// but there is no need to wait when the termination controller lists all agent pools,
// because the termination controller only garbage-collects leaked agent pools.
if len(nodes[0].Spec.ProviderID) == 0 {
// provider id is not found
return nil, nil
}
// tokens := strings.SplitAfter(node.Name, "-vmss") // remove the vm index "0000"
instanceLabels := lo.MapValues(apObj.Properties.NodeLabels, func(v *string, _ string) string {
return lo.FromPtr(v)
})
return &Instance{
Name: apObj.Name,
// ID: to.Ptr(fmt.Sprint("azure://", p.getVMSSNodeProviderID(lo.FromPtr(subID), tokens[0]))),
ID: to.Ptr(nodes[0].Spec.ProviderID),
Type: apObj.Properties.VMSize,
SubnetID: apObj.Properties.VnetSubnetID,
Tags: apObj.Properties.Tags,
State: apObj.Properties.ProvisioningState,
Labels: instanceLabels,
}, nil
}
// fromKaitoAgentPoolToInstance converts an agent pool owned by kaito to an Instance. Agent pools that have
// no associated node are also included, so that leaked agent pools can be garbage-collected.
func (p *Provider) fromKaitoAgentPoolToInstance(ctx context.Context, apObj *armcontainerservice.AgentPool) (*Instance, error) {
if apObj == nil {
return nil, fmt.Errorf("agent pool is nil")
}
instanceLabels := lo.MapValues(apObj.Properties.NodeLabels, func(v *string, _ string) string {
return lo.FromPtr(v)
})
ins := &Instance{
Name: apObj.Name,
Type: apObj.Properties.VMSize,
SubnetID: apObj.Properties.VnetSubnetID,
Tags: apObj.Properties.Tags,
State: apObj.Properties.ProvisioningState,
Labels: instanceLabels,
}
nodes, err := p.getNodesByName(ctx, lo.FromPtr(apObj.Name))
if err != nil {
return nil, err
}
if len(nodes) == 1 && len(nodes[0].Spec.ProviderID) != 0 {
ins.ID = to.Ptr(nodes[0].Spec.ProviderID)
}
return ins, nil
}
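// fromAPListToInstances converts the listed agent pools to Instances, skipping agent pools that are not
// owned by kaito or were not created from a nodeClaim. It returns a NodeClaimNotFoundError when no agent pool qualifies.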
func (p *Provider) fromAPListToInstances(ctx context.Context, apList []*armcontainerservice.AgentPool) ([]*Instance, error) {
instances := []*Instance{}
if len(apList) == 0 {
return instances, cloudprovider.NewNodeClaimNotFoundError(fmt.Errorf("agentpools not found"))
}
for index := range apList {
// skip agentPool that is not owned by kaito
if !agentPoolIsOwnedByKaito(apList[index]) {
continue
}
// skip agentPool that was not created from a nodeclaim
if !agentPoolIsCreatedFromNodeClaim(apList[index]) {
continue
}
instance, err := p.fromKaitoAgentPoolToInstance(ctx, apList[index])
if err != nil {
return instances, err
}
if instance != nil {
instances = append(instances, instance)
}
}
if len(instances) == 0 {
return instances, cloudprovider.NewNodeClaimNotFoundError(fmt.Errorf("agentpools not found"))
}
return instances, nil
}
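// newAgentPoolObject assembles the AgentPool creation payload from the nodeClaim spec: node labels,
// taints, VM size, OS type, and the OS disk size derived from the storage request.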
func newAgentPoolObject(vmSize string, nodeClaim *karpenterv1.NodeClaim) (armcontainerservice.AgentPool, error) {
taints := nodeClaim.Spec.Taints
taintsStr := []*string{}
for _, t := range taints {
taintsStr = append(taintsStr, to.Ptr(fmt.Sprintf("%s=%s:%s", t.Key, t.Value, t.Effect)))
}
scaleSetsType := armcontainerservice.AgentPoolTypeVirtualMachineScaleSets
// todo: why nodepool label is used here
labels := map[string]*string{karpenterv1.NodePoolLabelKey: to.Ptr("kaito")}
for k, v := range nodeClaim.Labels {
labels[k] = to.Ptr(v)
}
if strings.Contains(vmSize, "Standard_N") {
labels = lo.Assign(labels, map[string]*string{LabelMachineType: to.Ptr("gpu")})
} else {
labels = lo.Assign(labels, map[string]*string{LabelMachineType: to.Ptr("cpu")})
}
// NodeClaimCreationLabel records the creation timestamp of the agentPool resource. It is used by the
// garbage collection controller to clean up orphan agent pools that have lived for more than 10 minutes.
labels[NodeClaimCreationLabel] = to.Ptr(nodeClaim.CreationTimestamp.UTC().Format(CreationTimestampLayout))
storage := &resource.Quantity{}
if nodeClaim.Spec.Resources.Requests != nil {
storage = nodeClaim.Spec.Resources.Requests.Storage()
}
if storage.Value() <= 0 {
return armcontainerservice.AgentPool{}, fmt.Errorf("storage request of nodeclaim(%s) should be more than 0", nodeClaim.Name)
}
diskSizeGB := int32(storage.Value() >> 30)
return armcontainerservice.AgentPool{
Properties: &armcontainerservice.ManagedClusterAgentPoolProfileProperties{
NodeLabels: labels,
NodeTaints: taintsStr, //[]*string{to.Ptr("sku=gpu:NoSchedule")},
Type: to.Ptr(scaleSetsType),
VMSize: to.Ptr(vmSize),
OSType: to.Ptr(armcontainerservice.OSTypeLinux),
Count: to.Ptr(int32(1)),
OSDiskSizeGB: to.Ptr(diskSizeGB),
},
}, nil
}
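// getNodesByName lists the cluster nodes that carry the agent pool labels matching the given agent pool name.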
func (p *Provider) getNodesByName(ctx context.Context, apName string) ([]*v1.Node, error) {
nodeList := &v1.NodeList{}
labelSelector := client.MatchingLabels{"agentpool": apName, "kubernetes.azure.com/agentpool": apName}
err := retry.OnError(retry.DefaultRetry, func(err error) bool {
return true
}, func() error {
return p.kubeClient.List(ctx, nodeList, labelSelector)
})
if err != nil {
return nil, err
}
return lo.ToSlicePtr(nodeList.Items), nil
}
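// agentPoolIsOwnedByKaito reports whether the agent pool carries any of the kaito node labels.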
func agentPoolIsOwnedByKaito(ap *armcontainerservice.AgentPool) bool {
if ap == nil || ap.Properties == nil {
return false
}
// return true when agentpool.NodeLabels includes any of the kaito labels, otherwise false
for i := range KaitoNodeLabels {
if _, ok := ap.Properties.NodeLabels[KaitoNodeLabels[i]]; ok {
return true
}
}
return false
}
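// agentPoolIsCreatedFromNodeClaim reports whether the agent pool carries the karpenter nodepool label
// that is set when the pool is created from a nodeClaim.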
func agentPoolIsCreatedFromNodeClaim(ap *armcontainerservice.AgentPool) bool {
if ap == nil || ap.Properties == nil {
return false
}
// return true when agentpool.NodeLabels includes the nodepool label, otherwise false
if _, ok := ap.Properties.NodeLabels[karpenterv1.NodePoolLabelKey]; ok {
return true
}
return false
}