cmd/utils/vmss.go (274 lines of code) (raw):
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
package utils
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"os/signal"
"regexp"
"strings"
"syscall"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice"
"github.com/Azure/go-autorest/autorest/to"
"github.com/Azure/kubectl-aks/cmd/utils/config"
"github.com/kinvolk/inspektor-gadget/pkg/k8sutil"
log "github.com/sirupsen/logrus"
metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
// OutputTruncate selects which part of a command's output is dropped when the
// output exceeds BytesLimit.
type OutputTruncate int

const (
	// OutputTruncateHead keeps the output as returned by the Azure run command
	// API, which only returns the last bytes of the output.
	OutputTruncateHead OutputTruncate = iota
	// OutputTruncateTail keeps the first BytesLimit bytes of the command's
	// standard output and drops the rest.
	OutputTruncateTail
)

const (
	// BytesLimit is the output size, in bytes, used by RunCommand when
	// truncating output and when deciding whether output was truncated.
	BytesLimit = 4096
	// DefaultRunCommandTimeoutInSeconds is the timeout RunCommand applies when
	// the caller does not provide one.
	DefaultRunCommandTimeoutInSeconds = 300
)
type (
	// VirtualMachineScaleSetVM identifies a single VM instance of an Azure
	// virtual machine scale set (VMSS).
	VirtualMachineScaleSetVM struct {
		// SubscriptionID is the Azure subscription the scale set belongs to.
		SubscriptionID string
		// NodeResourceGroup is the resource group containing the scale set.
		NodeResourceGroup string
		// VMScaleSet is the name of the scale set.
		VMScaleSet string
		// InstanceID identifies the VM instance within the scale set.
		InstanceID string
	}

	// RunCommandResult holds the output captured from a command executed on
	// a VMSS VM.
	RunCommandResult struct {
		Stdout string
		Stderr string
	}
)
// ParseVMSSResourceID extracts elements from a given VMSS resource ID with format:
// /subscriptions/mySubID/resourceGroups/myRG/providers/myProvider/virtualMachineScaleSets/myVMSS/virtualMachines/myInsID
func ParseVMSSResourceID(id string, vm *VirtualMachineScaleSetVM) error {
const expectedItems int = 5
// This allows us to make resource ID (--id) option not case sentitive
id = strings.ToLower(id)
// Required because fmt.Sscanf expects space-separated values
idWithSpaces := strings.TrimSpace(strings.Replace(id, "/", " ", -1))
// We don't need the provider but fmt.Sscanf does not support "%*s" operator
// to read but prevent conversion. Therefore, read it and don't use it.
var provider string
n, err := fmt.Sscanf(idWithSpaces, "subscriptions %s resourcegroups %s providers %s virtualmachinescalesets %s virtualmachines %s",
&vm.SubscriptionID, &vm.NodeResourceGroup, &provider, &vm.VMScaleSet, &vm.InstanceID)
if err != nil {
return fmt.Errorf("error parsing provider ID %q: %w", id, err)
}
if n != expectedItems {
return fmt.Errorf("%d values retrieved while expecting %d when parsing id %s",
n, expectedItems, id)
}
return nil
}
// VirtualMachineScaleSetVMFromConfig returns the VirtualMachineScaleSetVM
// selected by the package-level settings (presumably populated from the
// command-line flags), checked in this order:
//
//   - node: the VMSS information stored in the config file for that node is
//     used if present. Otherwise the node's Azure resource ID is retrieved from
//     the API server into the package-level resourceID and parsed.
//   - resourceID: the given Azure resource ID is parsed.
//   - otherwise: subscriptionID, nodeResourceGroup, vmss and vmssInstanceID
//     are used as they are.
//
// It assumes that the config is set and valid.
func VirtualMachineScaleSetVMFromConfig() (*VirtualMachineScaleSetVM, error) {
	var vm VirtualMachineScaleSetVM

	switch {
	case node != "":
		// Before trying to get the resource ID from the API server, verify if
		// the VMSS information of that node is already in the config file.
		// The variable is named cfg so it does not shadow the config package.
		cfg := config.New()
		if cc, ok := cfg.GetNodeConfig(node); ok {
			log.Debugf("Using VMSS information from config for node %s", node)
			vm.SubscriptionID = cc.GetString(SubscriptionIDKey)
			vm.NodeResourceGroup = cc.GetString(NodeResourceGroupKey)
			vm.VMScaleSet = cc.GetString(VMSSKey)
			vm.InstanceID = cc.GetString(VMSSInstanceIDKey)
			return &vm, nil
		}

		// Note: this assigns the package-level resourceID, not a local variable.
		var err error
		resourceID, err = GetNodeResourceID(context.TODO(), node)
		if err != nil {
			return nil, fmt.Errorf("retrieving Azure resource ID of node %s from API server: %w",
				node, err)
		}
		if err = ParseVMSSResourceID(resourceID, &vm); err != nil {
			return nil, fmt.Errorf("parsing Azure resource ID %s: %w", resourceID, err)
		}
	case resourceID != "":
		if err := ParseVMSSResourceID(resourceID, &vm); err != nil {
			return nil, fmt.Errorf("parsing Azure resource ID %s: %w", resourceID, err)
		}
	default:
		vm.SubscriptionID = subscriptionID
		vm.NodeResourceGroup = nodeResourceGroup
		vm.VMScaleSet = vmss
		vm.InstanceID = vmssInstanceID
	}
	return &vm, nil
}
// VirtualMachineScaleSetVMsViaKubeconfig lists the cluster nodes through the
// Kubernetes API server configured by KubernetesConfigFlags. It returns the
// VMSS VM parsed from each node's provider ID, keyed by node name. It fails if
// any node's provider ID does not start with "azure://" or cannot be parsed as
// a VMSS VM resource ID.
func VirtualMachineScaleSetVMsViaKubeconfig() (map[string]*VirtualMachineScaleSetVM, error) {
	const azureProviderIDPrefix = "azure://"

	clientset, err := k8sutil.NewClientsetFromConfigFlags(KubernetesConfigFlags)
	if err != nil {
		return nil, fmt.Errorf("creating Kubernetes client: %w", err)
	}
	nodes, err := clientset.CoreV1().Nodes().List(context.TODO(), metaV1.ListOptions{})
	if err != nil {
		return nil, fmt.Errorf("listing nodes: %w", err)
	}

	vmssVMs := make(map[string]*VirtualMachineScaleSetVM, len(nodes.Items))
	// Index into the slice rather than ranging by value: each item is a full
	// Node object and would otherwise be copied on every iteration.
	for i := range nodes.Items {
		n := &nodes.Items[i]
		if !strings.HasPrefix(n.Spec.ProviderID, azureProviderIDPrefix) {
			return nil, fmt.Errorf("node=%q doesn't seem to be an Azure VMSS VM", n.Name)
		}
		var vm VirtualMachineScaleSetVM
		if err := ParseVMSSResourceID(strings.TrimPrefix(n.Spec.ProviderID, azureProviderIDPrefix), &vm); err != nil {
			return nil, fmt.Errorf("parsing Azure resource ID %q: %w", n.Spec.ProviderID, err)
		}
		vmssVMs[n.Name] = &vm
	}
	return vmssVMs, nil
}
// VirtualMachineScaleSetVMsViaAzureAPI uses the Azure API to list the VM
// instances of every scale set in the node resource group of the AKS cluster
// clusterName (resource group rg, subscription subID). The result is keyed by
// instanceName(instance). Each entry records the node resource group
// lowercased.
func VirtualMachineScaleSetVMsViaAzureAPI(subID, rg, clusterName string) (map[string]*VirtualMachineScaleSetVM, error) {
	creds, err := GetCredentials()
	if err != nil {
		return nil, fmt.Errorf("getting credentials: %w", err)
	}
	ctx := context.Background()

	aksClient, err := armcontainerservice.NewManagedClustersClient(subID, creds, nil)
	if err != nil {
		return nil, fmt.Errorf("creating AKS client: %w", err)
	}
	cluster, err := aksClient.Get(ctx, rg, clusterName, nil)
	if err != nil {
		return nil, fmt.Errorf("getting cluster: %w", err)
	}
	// Properties and NodeResourceGroup are pointers in the SDK model. Check
	// them instead of panicking on nil or listing an empty resource group.
	if cluster.Properties == nil || to.String(cluster.Properties.NodeResourceGroup) == "" {
		return nil, fmt.Errorf("cluster %q has no node resource group", clusterName)
	}
	nodeRG := to.String(cluster.Properties.NodeResourceGroup)

	vmssClient, err := armcompute.NewVirtualMachineScaleSetsClient(subID, creds, nil)
	if err != nil {
		return nil, fmt.Errorf("creating VMSS client: %w", err)
	}
	var nodePools []string
	nodePoolPager := vmssClient.NewListPager(nodeRG, nil)
	for nodePoolPager.More() {
		page, err := nodePoolPager.NextPage(ctx)
		if err != nil {
			return nil, fmt.Errorf("getting next page of node pools: %w", err)
		}
		for _, scaleSet := range page.Value {
			if scaleSet == nil {
				continue
			}
			nodePools = append(nodePools, to.String(scaleSet.Name))
		}
	}

	vmClient, err := armcompute.NewVirtualMachineScaleSetVMsClient(subID, creds, nil)
	if err != nil {
		return nil, fmt.Errorf("creating VMSS VMs client: %w", err)
	}
	vmssVMs := make(map[string]*VirtualMachineScaleSetVM)
	lowerNodeRG := strings.ToLower(nodeRG)
	for _, np := range nodePools {
		instances, err := instancesForNodePool(ctx, vmClient, np, nodeRG)
		if err != nil {
			return nil, fmt.Errorf("getting instances for node pool %q: %w", np, err)
		}
		for _, instance := range instances {
			if instance == nil {
				continue
			}
			vmssVMs[instanceName(instance)] = &VirtualMachineScaleSetVM{
				SubscriptionID:    subID,
				VMScaleSet:        np,
				NodeResourceGroup: lowerNodeRG,
				InstanceID:        to.String(instance.InstanceID),
			}
		}
	}
	return vmssVMs, nil
}
// instancesForNodePool returns every VM instance of the scale set pool in
// resourceGroup, following all pages of the list operation. The first paging
// error is returned unwrapped.
func instancesForNodePool(ctx context.Context, vmClient *armcompute.VirtualMachineScaleSetVMsClient, pool, resourceGroup string) ([]*armcompute.VirtualMachineScaleSetVM, error) {
	var all []*armcompute.VirtualMachineScaleSetVM
	for pager := vmClient.NewListPager(resourceGroup, pool, nil); pager.More(); {
		page, err := pager.NextPage(ctx)
		if err != nil {
			return nil, err
		}
		all = append(all, page.Value...)
	}
	return all, nil
}
// instanceName returns the instance name of the VMSS VM formatted as
// Kubernetes node name, which is its OS computer name lowercased. When the
// computer name is unavailable (including when Properties or OSProfile is
// nil), it falls back to the Azure resource name as-is. A nil vm yields "".
func instanceName(vm *armcompute.VirtualMachineScaleSetVM) string {
	if vm == nil {
		return ""
	}
	props := vm.Properties
	if props == nil || props.OSProfile == nil || props.OSProfile.ComputerName == nil {
		return to.String(vm.Name)
	}
	return strings.ToLower(*props.OSProfile.ComputerName)
}
// RunCommand runs command on the given VMSS VM through the Azure run command
// API (RunShellScript) and returns the captured stdout and stderr.
//
// The command is executed on the node as
// `timeout <timeout> sh -c '<command>'`. Single quotes in command are escaped
// so the whole command always reaches sh as one argument. timeout is in
// seconds and defaults to DefaultRunCommandTimeoutInSeconds when nil.
//
// With OutputTruncateTail, the command's standard output is piped through
// `head -c BytesLimit` on the node. When the result reaches the limit,
// "... (truncated)" is appended to Stdout. The string pointed to by command is
// never modified.
//
// While the command runs, the first SIGINT/SIGTERM only logs a warning and a
// second one exits the process with status 1. The handler is removed when
// RunCommand returns. When verbose is true, the VM and the raw API response are
// printed to stdout. An error is returned if the API reports a code other than
// "ProvisioningState/succeeded" or the message cannot be parsed.
func RunCommand(
	ctx context.Context,
	cred azcore.TokenCredential,
	vm *VirtualMachineScaleSetVM,
	command *string,
	verbose bool,
	timeout *int,
	outputTruncate OutputTruncate,
) (
	*RunCommandResult,
	error,
) {
	const (
		commandID   = "RunShellScript"
		pollingFreq = 2 * time.Second
	)

	timeoutSeconds := DefaultRunCommandTimeoutInSeconds
	if timeout != nil {
		timeoutSeconds = *timeout
	}

	client, err := armcompute.NewVirtualMachineScaleSetVMsClient(vm.SubscriptionID, cred, nil)
	if err != nil {
		return nil, fmt.Errorf("creating VMSS VMs client: %w", err)
	}

	// Work on a copy so the caller's command is never modified. Appending the
	// truncation pipe to *command would make it grow on every call that
	// reuses the same pointer.
	cmd := *command
	// By default, the Azure API limits the output to the last 4,096 bytes. See
	// https://learn.microsoft.com/en-us/azure/virtual-machines/linux/run-command#restrictions.
	// To keep the beginning of the output instead, cut it on the node itself.
	if outputTruncate == OutputTruncateTail {
		cmd = fmt.Sprintf("%s | head -c %d", cmd, BytesLimit)
	}
	// The command is passed to sh as a single-quoted argument. Each single
	// quote it contains is therefore rewritten as '\'' (close the quoted
	// string, add an escaped quote, reopen it). Otherwise a quote in the
	// command would end the argument early.
	quotedCmd := strings.ReplaceAll(cmd, "'", `'\''`)
	runCommand := armcompute.RunCommandInput{
		CommandID: to.StringPtr(commandID),
		Script:    []*string{to.StringPtr(fmt.Sprintf("timeout %d sh -c '%s'", timeoutSeconds, quotedCmd))},
	}

	if verbose {
		// Best-effort debug output: a struct of plain strings always marshals.
		b, _ := json.MarshalIndent(vm, "", " ")
		fmt.Printf("Command: %s\nVirtual Machine Scale Set VM:\n%s\n\n", cmd, string(b))
	}

	// Intercept interrupts while the command runs, so a single Ctrl+C doesn't
	// silently abandon a command that keeps running on the node.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
	done := make(chan struct{})
	// On return, restore default signal handling and release the watcher
	// goroutine. Later interrupts then behave normally, and repeated calls do
	// not accumulate goroutines.
	defer func() {
		signal.Stop(sigCh)
		close(done)
	}()
	go func() {
		select {
		case <-sigCh:
		case <-done:
			return
		}
		log.Warn("The requested command hasn't finished yet, hit 'Ctrl+C' again to exit anyway.")
		log.Warn("However, please notice the command will continue running in the node anyway, " +
			"and you will be unable to see the output or run another command until it finishes.")
		select {
		case <-sigCh:
			os.Exit(1)
		case <-done:
		}
	}()

	// Set the suffix before starting the spinner so it is shown from the
	// first frame.
	DefaultSpinner.Suffix = " Running..."
	DefaultSpinner.Start()
	poller, err := client.BeginRunCommand(ctx, vm.NodeResourceGroup,
		vm.VMScaleSet, vm.InstanceID, runCommand, nil)
	if err != nil {
		DefaultSpinner.Stop()
		return nil, fmt.Errorf("begin running command: %w", err)
	}
	res, err := poller.PollUntilDone(ctx, &runtime.PollUntilDoneOptions{Frequency: pollingFreq})
	DefaultSpinner.Stop()
	if err != nil {
		return nil, fmt.Errorf("polling command response: %w", err)
	}

	if verbose {
		// Best-effort debug output; a marshalling failure only affects what is printed.
		b, _ := json.MarshalIndent(res, "", " ")
		fmt.Printf("\nResponse:\n%s\n", string(b))
	}

	// TODO: Is it possible to have multiple values after using PollUntilDone()?
	if len(res.Value) == 0 || res.Value[0] == nil {
		return nil, errors.New("no response received after command execution")
	}
	val := res.Value[0]
	// TODO: Isn't there a constant in the SDK to compare this?
	if to.String(val.Code) != "ProvisioningState/succeeded" {
		b, _ := json.MarshalIndent(res, "", " ")
		return nil, fmt.Errorf("command execution didn't succeed:\n%s", string(b))
	}

	result, err := parseRunCommandMessage(to.String(val.Message))
	if err != nil {
		return nil, err
	}
	if outputTruncate == OutputTruncateTail && result.isTruncated() {
		result.Stdout = fmt.Sprintf("%s... (truncated)\n", result.Stdout)
	}
	return result, nil
}
// runCommandSectionRE matches the "[stdout]"/"[stderr]" section headers of a
// run command response message. It is compiled once, at package
// initialization, rather than on every parse.
var runCommandSectionRE = regexp.MustCompile(`\n\[(stdout|stderr)\]\n`)

// parseRunCommandMessage extracts stdout and stderr from the message of a run
// command response. The expected format is:
//
//	Enable succeeded: \n[stdout]\n<stdout>\n[stderr]\n<stderr>
//
// The "Enable succeeded: " prefix is optional. An error is returned unless the
// message contains exactly two section headers.
func parseRunCommandMessage(msg string) (*RunCommandResult, error) {
	res := strings.TrimPrefix(msg, "Enable succeeded: ")
	parts := runCommandSectionRE.Split(res, -1)
	if len(parts) != 3 {
		return nil, fmt.Errorf("couldn't parse response message:\n%s", res)
	}
	return &RunCommandResult{
		Stdout: parts[1],
		Stderr: parts[2],
	}, nil
}
// isTruncated reports whether stdout and stderr together have reached
// BytesLimit bytes, which RunCommand treats as a sign that the output was cut.
func (r *RunCommandResult) isTruncated() bool {
	total := len(r.Stdout) + len(r.Stderr)
	return total >= BytesLimit
}