controllers/daemon/staticgatewayconfiguration_controller.go (965 lines of code) (raw):

// Copyright (c) Microsoft Corporation. // Licensed under the MIT license. package daemon import ( "bytes" "context" "fmt" "net" "os" "strconv" "strings" "github.com/containernetworking/plugins/pkg/ns" "github.com/vishvananda/netlink" "github.com/vishvananda/netlink/nl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" utiliptables "k8s.io/kubernetes/pkg/util/iptables" utilexec "k8s.io/utils/exec" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" "sigs.k8s.io/controller-runtime/pkg/event" "sigs.k8s.io/controller-runtime/pkg/handler" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/reconcile" "sigs.k8s.io/controller-runtime/pkg/source" egressgatewayv1alpha1 "github.com/Azure/kube-egress-gateway/api/v1alpha1" "github.com/Azure/kube-egress-gateway/pkg/consts" "github.com/Azure/kube-egress-gateway/pkg/healthprobe" "github.com/Azure/kube-egress-gateway/pkg/imds" "github.com/Azure/kube-egress-gateway/pkg/netlinkwrapper" "github.com/Azure/kube-egress-gateway/pkg/netnswrapper" "github.com/Azure/kube-egress-gateway/pkg/utils/to" "github.com/Azure/kube-egress-gateway/pkg/wgctrlwrapper" ) var _ reconcile.Reconciler = &StaticGatewayConfigurationReconciler{} // StaticGatewayConfigurationReconciler reconciles gateway node network according to a StaticGatewayConfiguration object type StaticGatewayConfigurationReconciler struct { client.Client TickerEvents chan event.GenericEvent LBProbeServer *healthprobe.LBProbeServer Netlink netlinkwrapper.Interface NetNS netnswrapper.Interface IPTables utiliptables.Interface WgCtrl wgctrlwrapper.Interface } //+kubebuilder:rbac:groups=egressgateway.kubernetes.azure.com,resources=staticgatewayconfigurations,verbs=get;list;watch //+kubebuilder:rbac:groups=egressgateway.kubernetes.azure.com,resources=staticgatewayconfigurations/status,verbs=get;update;patch //+kubebuilder:rbac:groups=egressgateway.kubernetes.azure.com,resources=gatewayvmconfigurations,verbs=get;list;watch //+kubebuilder:rbac:groups=egressgateway.kubernetes.azure.com,resources=gatewayvmconfigurations/status,verbs=get //+kubebuilder:rbac:groups=core,namespace=kube-egress-gateway-system,resources=secrets,verbs=get;list;watch //+kubebuilder:rbac:groups=egressgateway.kubernetes.azure.com,resources=gatewaystatuses,verbs=get;list;watch;create;update;patch // Reconcile is part of the main kubernetes reconciliation loop which aims to // move the current state of the cluster closer to the desired state. // TODO(user): Modify the Reconcile function to compare the state specified by // the StaticGatewayConfiguration object against the actual cluster state, and then // perform operations to make the cluster state reflect the state specified by // the user. // // For more details, check Reconcile and its Result here: // - https://pkg.go.dev/sigs.k8s.io/controller-runtime@v0.13.0/pkg/reconcile var ( nodeMeta *imds.InstanceMetadata lbMeta *imds.LoadBalancerMetadata nodeTags map[string]string ) func InitNodeMetadata() error { var err error nodeMeta, err = imds.GetInstanceMetadata() if err != nil { return err } lbMeta, err = imds.GetLoadBalancerMetadata() if err != nil { return err } if nodeMeta == nil || lbMeta == nil { return fmt.Errorf("failed to setup controller: nodeMeta or lbMeta is nil") } nodeTags = parseNodeTags() return nil } func (r *StaticGatewayConfigurationReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { log := log.FromContext(ctx) // Got an event from cleanup ticker if req.NamespacedName.Namespace == "" && req.NamespacedName.Name == "" { if err := r.cleanUp(ctx); err != nil { return ctrl.Result{}, fmt.Errorf("failed to clean up orphaned network configurations: %w", err) } return ctrl.Result{}, nil } // Fetch the StaticGatewayConfiguration instance. gwConfig := &egressgatewayv1alpha1.StaticGatewayConfiguration{} if err := r.Get(ctx, req.NamespacedName, gwConfig); err != nil { if apierrors.IsNotFound(err) { // Object not found, return. return ctrl.Result{}, nil } log.Error(err, "unable to fetch StaticGatewayConfiguration instance") return ctrl.Result{}, err } if !isReady(gwConfig) { // gateway setup hasn't completed yet return ctrl.Result{}, nil } if !applyToNode(gwConfig) { // gwConfig does not apply to this node return ctrl.Result{}, nil } if !gwConfig.ObjectMeta.DeletionTimestamp.IsZero() { if err := r.cleanUp(ctx); err != nil { return ctrl.Result{}, fmt.Errorf("failed to clean up deleted StaticGatewayConfiguration %s/%s: %w", gwConfig.Namespace, gwConfig.Name, err) } return ctrl.Result{}, nil } // Reconcile gateway configuration return ctrl.Result{}, r.reconcile(ctx, gwConfig) } // SetupWithManager sets up the controller with the Manager. func (r *StaticGatewayConfigurationReconciler) SetupWithManager(mgr ctrl.Manager) error { r.Netlink = netlinkwrapper.NewNetLink() r.NetNS = netnswrapper.NewNetNS() r.IPTables = utiliptables.New(utilexec.New(), utiliptables.ProtocolIPv4) r.WgCtrl = wgctrlwrapper.NewWgCtrl() controller, err := ctrl.NewControllerManagedBy(mgr). For(&egressgatewayv1alpha1.StaticGatewayConfiguration{}). // We need to watch GatewayVMConfiguration also, because vmSecondaryIP may change, e.g. duing upgrade // we can use EnqueueRequestForObject because GatewayVMConfiguration has the same namespace/name as StaticGatewayConfiguration Watches(&egressgatewayv1alpha1.GatewayVMConfiguration{}, &handler.EnqueueRequestForObject{}). Build(r) if err != nil { return err } return controller.Watch(source.Channel(r.TickerEvents, &handler.EnqueueRequestForObject{})) } func (r *StaticGatewayConfigurationReconciler) reconcile( ctx context.Context, gwConfig *egressgatewayv1alpha1.StaticGatewayConfiguration, ) error { log := log.FromContext(ctx) log.Info("Reconciling gateway configuration") // get wireguard private key from secret privateKey, err := r.getWireguardPrivateKey(ctx, gwConfig) if err != nil { return err } // add lb ip (if not exists) to eth0 if err := r.reconcileIlbIPOnHost(ctx, gwConfig.Status.GatewayServerProfile.Ip); err != nil { return err } // remove secondary ip from eth0 vmPrimaryIP, vmSecondaryIP, err := r.getVMIP(ctx, gwConfig) if err != nil { return err } if err := r.removeSecondaryIpFromHost(ctx, vmSecondaryIP); err != nil { return err } // avoid masquerading packets from gateway namespace, as they're already sNATed if err := r.ensureIPTablesChain( ctx, utiliptables.TableNAT, utiliptables.Chain("EGRESS-GATEWAY-SNAT"), // target chain utiliptables.ChainPostrouting, // source chain "kube-egress-gateway no MASQUERADE", nil); err != nil { return err } if err := r.ensureIPTablesChain( ctx, utiliptables.TableNAT, utiliptables.Chain(fmt.Sprintf("EGRESS-%s", strings.ReplaceAll(vmSecondaryIP, ".", "-"))), // target chain utiliptables.Chain("EGRESS-GATEWAY-SNAT"), // source chain fmt.Sprintf("kube-egress-gateway no sNAT packet from ip %s", vmSecondaryIP), [][]string{ {"-s", vmSecondaryIP + "/32", "-j", "ACCEPT"}, }); err != nil { return err } // configure gateway namespace (if not exists) if err := r.configureGatewayNamespace(ctx, gwConfig, privateKey, vmPrimaryIP, vmSecondaryIP); err != nil { return err } // update gateway status gwStatus := egressgatewayv1alpha1.GatewayConfiguration{ StaticGatewayConfiguration: fmt.Sprintf("%s/%s", gwConfig.Namespace, gwConfig.Name), InterfaceName: getWireguardInterfaceName(gwConfig), } if err := r.updateGatewayNodeStatus(ctx, gwStatus, true /* add */); err != nil { return err } if err := r.LBProbeServer.AddGateway(string(gwConfig.GetUID())); err != nil { return err } log.Info("Gateway configuration reconciled") return nil } func (r *StaticGatewayConfigurationReconciler) cleanUp(ctx context.Context) error { log := log.FromContext(ctx) log.Info("Cleaning up orphaned gateway network configurations") gwConfigList := &egressgatewayv1alpha1.StaticGatewayConfigurationList{} if err := r.List(ctx, gwConfigList); err != nil { return fmt.Errorf("failed to list staticGatewayConfigurations: %w", err) } existingWgLinks := make(map[string]struct{}) existingIPs := make(map[string]struct{}) hasActiveGateway := false for _, gwConfig := range gwConfigList.Items { if applyToNode(&gwConfig) && gwConfig.DeletionTimestamp.IsZero() { _, vmSecondaryIP, err := r.getVMIP(ctx, &gwConfig) if err != nil { log.Error(err, "failed to get VM secondaryIP during cleanup", "gwConfig", fmt.Sprintf("%s/%s", gwConfig.Namespace, gwConfig.Name)) continue } existingWgLinks[getWireguardInterfaceName(&gwConfig)] = struct{}{} existingIPs[vmSecondaryIP] = struct{}{} hasActiveGateway = true } } gwns, err := r.NetNS.GetNS(consts.GatewayNetnsName) if err != nil { return fmt.Errorf("failed to get network namespace %s: %w", consts.GatewayNetnsName, err) } defer gwns.Close() var links []netlink.Link var ips []netlink.Addr if err := gwns.Do(func(nn ns.NetNS) error { var err error links, err = r.Netlink.LinkList() if err != nil { return fmt.Errorf("failed to list links in gateway namespace: %w", err) } hostLink, err := r.Netlink.LinkByName(consts.HostLinkName) if err != nil { return fmt.Errorf("failed to get host link in gateway namespace: %w", err) } ips, err = r.Netlink.AddrList(hostLink, nl.FAMILY_ALL) if err != nil { return fmt.Errorf("failed to list addresses on host0 in gateway namespace: %w", err) } return nil }); err != nil { return err } for _, ip := range ips { if _, ok := existingIPs[ip.IP.String()]; !ok { log.Info("Removing orphaned IP", "ip", ip.IP.String()) if err := r.ensureDeleteIP(ctx, gwns, ip); err != nil { log.Error(err, fmt.Sprintf("failed to cleanup vmSecondaryIP %s", ip.IP.String())) } } } for _, link := range links { if strings.HasPrefix(link.Attrs().Name, consts.WiregaurdLinkNamePrefix) { if _, ok := existingWgLinks[link.Attrs().Name]; !ok { log.Info("Removing orphaned wireguard link", "link", link.Attrs().Name) if err := r.ensureDeleteLink(ctx, gwns, link); err != nil { log.Error(err, fmt.Sprintf("failed to cleanup wireguard link %s", link.Attrs().Name)) } } } } if !hasActiveGateway { log.Info("No active gateway found, cleaning up leftover network configurations") if err := r.reconcileIlbIPOnHost(ctx, ""); err != nil { return fmt.Errorf("failed to cleanup ILB IP on host: %w", err) } if err := r.removeIPTablesChains( ctx, utiliptables.TableNAT, []utiliptables.Chain{utiliptables.Chain("EGRESS-GATEWAY-SNAT")}, []utiliptables.Chain{utiliptables.ChainPostrouting}, []string{"kube-egress-gateway no MASQUERADE"}, ); err != nil { return fmt.Errorf("failed to delete iptables chain EGRESS-GATEWAY-SNAT: %w", err) } } log.Info("Network namespace cleanup completed") return nil } func (r *StaticGatewayConfigurationReconciler) ensureDeleteLink(ctx context.Context, gwns ns.NetNS, link netlink.Link) error { log := log.FromContext(ctx) linkName := link.Attrs().Name if err := gwns.Do(func(nn ns.NetNS) error { log.Info("Deleting link", "link", link.Attrs().Name) err := r.Netlink.LinkDel(link) if err != nil { return fmt.Errorf("failed to delete link %s: %w", linkName, err) } mark, err := getPacketMark(linkName) if err != nil { return err } log.Info("Removing iptables rules", "mark", mark) if err := r.removeIPTablesChains( ctx, utiliptables.TableNAT, []utiliptables.Chain{ utiliptables.Chain(fmt.Sprintf("EGRESS-GATEWAY-MARK-%d", mark)), utiliptables.Chain(fmt.Sprintf("EGRESS-GATEWAY-SNAT-%d", mark)), }, // target chain []utiliptables.Chain{ utiliptables.ChainPrerouting, utiliptables.ChainPostrouting, }, // source chain []string{ fmt.Sprintf("kube-egress-gateway mark packets from gateway link %s", linkName), fmt.Sprintf("kube-egress-gateway sNAT packets from gateway link %s", linkName), }, ); err != nil { return fmt.Errorf("failed to cleanup iptables rules for link %s and mark %d: %w", linkName, mark, err) } return nil }); err != nil { return err } // update gateway status gwStatus := egressgatewayv1alpha1.GatewayConfiguration{ InterfaceName: link.Attrs().Name, } if err := r.updateGatewayNodeStatus(ctx, gwStatus, false /* add */); err != nil { return err } if err := r.LBProbeServer.RemoveGateway(link.Attrs().Alias); err != nil { return err } return nil } func (r *StaticGatewayConfigurationReconciler) ensureDeleteIP(ctx context.Context, gwns ns.NetNS, ip netlink.Addr) error { log := log.FromContext(ctx) if err := gwns.Do(func(nn ns.NetNS) error { log.Info("Deleting IP from host0", "ip", ip.IP.String()) hostLink, err := r.Netlink.LinkByName(consts.HostLinkName) if err != nil { return fmt.Errorf("failed to get host link in gateway namespace: %w", err) } if err := r.Netlink.AddrDel(hostLink, &ip); err != nil { return fmt.Errorf("failed to delete IP %s: %w", ip.IP.String(), err) } return nil }); err != nil { return err } routes, err := r.Netlink.RouteList(nil, nl.FAMILY_ALL) if err != nil { return fmt.Errorf("failed to list routes in host namespace: %w", err) } for _, route := range routes { route := route if route.Dst != nil && route.Dst.IP.Equal(ip.IP) { log.Info("Deleting route in host namespace to vmSecondaryIP", "route", route) if err := r.Netlink.RouteDel(&route); err != nil { return fmt.Errorf("failed to delete route to %s: %w", ip.IP.String(), err) } } } log.Info("Deleting no-sNAT rule for vmSecondaryIP", "ip", ip.IP.String()) if err := r.removeIPTablesChains( ctx, utiliptables.TableNAT, []utiliptables.Chain{utiliptables.Chain(fmt.Sprintf("EGRESS-%s", strings.ReplaceAll(ip.IP.String(), ".", "-")))}, // target chain []utiliptables.Chain{utiliptables.Chain("EGRESS-GATEWAY-SNAT")}, // source chain []string{fmt.Sprintf("kube-egress-gateway no sNAT packet from ip %s", ip.IP.String())}, ); err != nil { return fmt.Errorf("failed to clean up no-sNAT rule for vmSecondaryIP %s: %w", ip.IP.String(), err) } return nil } func (r *StaticGatewayConfigurationReconciler) getWireguardPrivateKey( ctx context.Context, gwConfig *egressgatewayv1alpha1.StaticGatewayConfiguration, ) (*wgtypes.Key, error) { secretKey := &types.NamespacedName{ Namespace: gwConfig.Status.PrivateKeySecretRef.Namespace, Name: gwConfig.Status.PrivateKeySecretRef.Name, } secret := &corev1.Secret{} if err := r.Get(ctx, *secretKey, secret); err != nil { return nil, fmt.Errorf("failed to retrieve wireguard private key secret: %w", err) } wgPrivateKeyByte, ok := secret.Data[consts.WireguardPrivateKeyName] if !ok { return nil, fmt.Errorf("failed to retrieve private key from secret %s/%s", secretKey.Namespace, secretKey.Name) } wgPrivateKey, err := wgtypes.ParseKey(string(wgPrivateKeyByte)) if err != nil { return nil, err } return &wgPrivateKey, nil } func (r *StaticGatewayConfigurationReconciler) getVMIP( ctx context.Context, gwConfig *egressgatewayv1alpha1.StaticGatewayConfiguration, ) (string, string, error) { log := log.FromContext(ctx) nodeName := nodeMeta.Compute.OSProfile.ComputerName var primaryIP, secondaryIP string // Fetch the StaticGatewayConfiguration instance. vmConfig := &egressgatewayv1alpha1.GatewayVMConfiguration{} if err := r.Get(ctx, types.NamespacedName{Namespace: gwConfig.Namespace, Name: gwConfig.Name}, vmConfig); err != nil { return "", "", err } // this can happen in cleanup process when vmConfig is not ready yet if vmConfig.Status == nil { return "", "", fmt.Errorf("status is nil for GatewayVMConfiguration %s/%s", vmConfig.Namespace, vmConfig.Name) } for _, vmProfile := range vmConfig.Status.GatewayVMProfiles { if vmProfile.NodeName == nodeName { primaryIP = vmProfile.PrimaryIP secondaryIP = vmProfile.SecondaryIP break } } if primaryIP == "" || secondaryIP == "" { return "", "", fmt.Errorf("failed to find primary or secondary IP for node %s", nodeName) } log.Info("Found primary and secondary IP for node", "nodeName", nodeName, "primaryIP", primaryIP, "secondaryIP", secondaryIP) return primaryIP, secondaryIP, nil } func isReady(gwConfig *egressgatewayv1alpha1.StaticGatewayConfiguration) bool { wgProfile := gwConfig.Status.GatewayServerProfile return gwConfig.Status.EgressIpPrefix != "" && wgProfile.Ip != "" && wgProfile.Port != 0 && wgProfile.PublicKey != "" && wgProfile.PrivateKeySecretRef != nil } func applyToNode(gwConfig *egressgatewayv1alpha1.StaticGatewayConfiguration) bool { if gwConfig.Spec.GatewayNodepoolName != "" { name, ok := nodeTags[consts.AKSNodepoolTagKey] return ok && strings.EqualFold(name, gwConfig.Spec.GatewayNodepoolName) } else { vmssProfile := gwConfig.Spec.GatewayVmssProfile return strings.EqualFold(vmssProfile.VmssName, nodeMeta.Compute.VMScaleSetName) && strings.EqualFold(vmssProfile.VmssResourceGroup, nodeMeta.Compute.ResourceGroupName) } } func parseNodeTags() map[string]string { tags := make(map[string]string) tagStrs := strings.Split(nodeMeta.Compute.Tags, ";") for _, tag := range tagStrs { kv := strings.Split(tag, ":") if len(kv) == 2 { tags[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1]) } } return tags } func (r *StaticGatewayConfigurationReconciler) reconcileIlbIPOnHost(ctx context.Context, ilbIP string) error { log := log.FromContext(ctx) eth0, err := r.Netlink.LinkByName("eth0") if err != nil { return fmt.Errorf("failed to retrieve link eth0: %w", err) } if len(nodeMeta.Network.Interface) == 0 || len(nodeMeta.Network.Interface[0].IPv4.Subnet) == 0 { return fmt.Errorf("imds does not provide subnet information about the node") } prefix, err := strconv.Atoi(nodeMeta.Network.Interface[0].IPv4.Subnet[0].Prefix) if err != nil { return fmt.Errorf("failed to retrieve and parse prefix: %w", err) } addresses, err := r.Netlink.AddrList(eth0, nl.FAMILY_ALL) if err != nil { return fmt.Errorf("failed to retrieve IP addresses for eth0: %w", err) } if ilbIP == "" { // cleanup process for _, address := range addresses { if address.Label == consts.ILBIPLabel { log.Info("Removing ILB IP from eth0", "ilb IP", address.IPNet.String()) if err := r.Netlink.AddrDel(eth0, &address); err != nil { return fmt.Errorf("failed to delete ILB IP from eth0: %w", err) } } } return nil } ilbIpCidr := fmt.Sprintf("%s/%d", ilbIP, prefix) ilbIpNet, err := netlink.ParseIPNet(ilbIpCidr) if err != nil { return fmt.Errorf("failed to parse ILB IP address: %s", ilbIpCidr) } addressPresent := false for _, address := range addresses { if address.IPNet.IP.Equal(ilbIpNet.IP) { addressPresent = true break } } if !addressPresent { log.Info("Adding ILB IP to eth0", "ilb IP", ilbIpCidr) if err := r.Netlink.AddrAdd(eth0, &netlink.Addr{ IPNet: ilbIpNet, Label: consts.ILBIPLabel, }); err != nil { return fmt.Errorf("failed to add ILB IP to eth0: %w", err) } } return nil } func (r *StaticGatewayConfigurationReconciler) removeSecondaryIpFromHost(ctx context.Context, ip string) error { log := log.FromContext(ctx) eth0, err := r.Netlink.LinkByName("eth0") if err != nil { return fmt.Errorf("failed to retrieve link eth0: %w", err) } addresses, err := r.Netlink.AddrList(eth0, nl.FAMILY_ALL) if err != nil { return fmt.Errorf("failed to retrieve IP addresses for eth0: %w", err) } for _, address := range addresses { if address.IP.String() == ip { log.Info("Removing secondary IP from eth0", "secondary_ip", address.IP.String()) if err := r.Netlink.AddrDel(eth0, &address); err != nil { return fmt.Errorf("failed to remove secondary ip from eth0: %w", err) } } } return nil } func (r *StaticGatewayConfigurationReconciler) configureGatewayNamespace( ctx context.Context, gwConfig *egressgatewayv1alpha1.StaticGatewayConfiguration, privateKey *wgtypes.Key, vmPrimaryIP string, vmSecondaryIP string, ) error { gwns, err := r.NetNS.GetNS(consts.GatewayNetnsName) if err != nil { return fmt.Errorf("failed to get network namespace %s: %w", consts.GatewayNetnsName, err) } defer gwns.Close() if err := r.reconcileWireguardLink(ctx, gwns, gwConfig, privateKey); err != nil { return err } if err := r.reconcileVethPair(ctx, gwns, vmPrimaryIP, vmSecondaryIP); err != nil { return err } return gwns.Do(func(nn ns.NetNS) error { looplink, err := r.Netlink.LinkByName("lo") if err != nil { return fmt.Errorf("failed to retrieve link lo: %w", err) } if err := r.Netlink.LinkSetUp(looplink); err != nil { return fmt.Errorf("failed to set lo up: %w", err) } linkName := getWireguardInterfaceName(gwConfig) mark, err := getPacketMark(linkName) if err != nil { return err } if err := r.ensureIPTablesChain( ctx, utiliptables.TableNAT, utiliptables.Chain(fmt.Sprintf("EGRESS-GATEWAY-MARK-%d", mark)), // target chain utiliptables.ChainPrerouting, // source chain fmt.Sprintf("kube-egress-gateway mark packets from gateway link %s", linkName), [][]string{ {"-i", linkName, "-j", "CONNMARK", "--set-mark", fmt.Sprintf("%d", mark)}, }); err != nil { return err } if err := r.ensureIPTablesChain( ctx, utiliptables.TableNAT, utiliptables.Chain(fmt.Sprintf("EGRESS-GATEWAY-SNAT-%d", mark)), // target chain utiliptables.ChainPostrouting, // source chain fmt.Sprintf("kube-egress-gateway sNAT packets from gateway link %s", linkName), [][]string{ {"-o", consts.HostLinkName, "-m", "connmark", "--mark", fmt.Sprintf("%d", mark), "-j", "SNAT", "--to-source", vmSecondaryIP}, }); err != nil { return err } return nil }) } func (r *StaticGatewayConfigurationReconciler) reconcileWireguardLink( ctx context.Context, gwns ns.NetNS, gwConfig *egressgatewayv1alpha1.StaticGatewayConfiguration, privateKey *wgtypes.Key, ) error { log := log.FromContext(ctx) linkName := getWireguardInterfaceName(gwConfig) var wgLink netlink.Link var err error if err = gwns.Do(func(nn ns.NetNS) error { wgLink, err = r.Netlink.LinkByName(linkName) if err != nil { if _, ok := err.(netlink.LinkNotFoundError); !ok { return fmt.Errorf("failed to get wireguard link in gateway namespace: %w", err) } wgLink = nil } return nil }); err != nil { return err } if wgLink == nil { log.Info("Creating wireguard link") if err := r.createWireguardLink(gwns, linkName, string(gwConfig.GetUID())); err != nil { return fmt.Errorf("failed to create wireguard link: %w", err) } } return gwns.Do(func(nn ns.NetNS) error { wgLink, err := r.Netlink.LinkByName(linkName) if err != nil { return fmt.Errorf("failed to get wireguard link in gateway namespace after creation: %w", err) } gwIP, _ := netlink.ParseIPNet(consts.GatewayIP) gwLinkAddr := netlink.Addr{ IPNet: gwIP, } wgLinkAddrs, err := r.Netlink.AddrList(wgLink, nl.FAMILY_ALL) if err != nil { return fmt.Errorf("failed to retrieve address list from wireguard link: %w", err) } foundLink := false for _, addr := range wgLinkAddrs { if addr.Equal(gwLinkAddr) { log.Info("Found wireguard link address") foundLink = true break } } if !foundLink { log.Info("Adding wireguard link address") err = r.Netlink.AddrAdd(wgLink, &gwLinkAddr) if err != nil { return fmt.Errorf("failed to add wireguard link address: %w", err) } } err = r.Netlink.LinkSetUp(wgLink) if err != nil { return fmt.Errorf("failed to set wireguard link up: %w", err) } wgClient, err := r.WgCtrl.New() if err != nil { return fmt.Errorf("failed to create wgctrl client: %w", err) } defer func() { _ = wgClient.Close() }() wgConfig := wgtypes.Config{ ListenPort: to.Ptr(int(gwConfig.Status.Port)), PrivateKey: privateKey, } device, err := wgClient.Device(linkName) if err != nil { return fmt.Errorf("failed to get wireguard link configuration: %w", err) } if device.PrivateKey.String() != wgConfig.PrivateKey.String() || device.ListenPort != to.Val(wgConfig.ListenPort) { log.Info("Updating wireguard link config", "orig port", device.ListenPort, "cur port", to.Val(wgConfig.ListenPort), "private key difference", device.PrivateKey.String() != wgConfig.PrivateKey.String()) err = wgClient.ConfigureDevice(linkName, wgConfig) if err != nil { return fmt.Errorf("failed to add peer to wireguard link: %w", err) } } return nil }) } func (r *StaticGatewayConfigurationReconciler) createWireguardLink(gwns ns.NetNS, linkName, linkAlias string) error { succeed := false attr := netlink.NewLinkAttrs() attr.Name = linkName attr.Alias = linkAlias wg := &netlink.Wireguard{LinkAttrs: attr} err := r.Netlink.LinkAdd(wg) if err != nil { return fmt.Errorf("failed to create wireguard link: %w", err) } defer func() { if !succeed { _ = r.Netlink.LinkDel(wg) } }() wgLink, err := r.Netlink.LinkByName(linkName) if err != nil { return fmt.Errorf("failed to get wireguard link in host namespace: %w", err) } if err := r.Netlink.LinkSetNsFd(wgLink, int(gwns.Fd())); err != nil { return fmt.Errorf("failed to move wireguard link to gateway namespace: %w", err) } succeed = true return nil } func (r *StaticGatewayConfigurationReconciler) reconcileVethPair( ctx context.Context, gwns ns.NetNS, vmPrimaryIP string, vmSecondaryIP string, ) error { log := log.FromContext(ctx) if err := r.reconcileVethPairInHost(ctx, gwns, vmSecondaryIP); err != nil { return fmt.Errorf("failed to reconcile veth pair in host namespace: %w", err) } return gwns.Do(func(nn ns.NetNS) error { hostLink, err := r.Netlink.LinkByName(consts.HostLinkName) if err != nil { return fmt.Errorf("failed to get host link in gateway namespace: %w", err) } _, snatIPNet, err := net.ParseCIDR(vmSecondaryIP + "/32") if err != nil { return fmt.Errorf("failed to parse SNAT IP(%s) for host interface: %w", vmSecondaryIP+"/32", err) } hostLinkAddr := netlink.Addr{IPNet: snatIPNet} hostLinkAddrs, err := r.Netlink.AddrList(hostLink, nl.FAMILY_ALL) if err != nil { return fmt.Errorf("failed to retrieve address list from wireguard link: %w", err) } foundLink := false for _, addr := range hostLinkAddrs { if addr.Equal(hostLinkAddr) { log.Info("Found host link address in gateway namespace") foundLink = true break } } if !foundLink { log.Info("Adding host link address in gateway namespace") err = r.Netlink.AddrAdd(hostLink, &hostLinkAddr) if err != nil { return fmt.Errorf("failed to add host link address in gateway namespace: %w", err) } } err = r.Netlink.LinkSetUp(hostLink) if err != nil { return fmt.Errorf("failed to set host link up: %w", err) } _, vmSnatCidr, err := net.ParseCIDR(vmPrimaryIP + "/32") if err != nil { return fmt.Errorf("failed to parse CIDR %s/32: %w", vmPrimaryIP+"/32", err) } err = r.addOrReplaceRoute(ctx, &netlink.Route{ LinkIndex: hostLink.Attrs().Index, Scope: netlink.SCOPE_LINK, Dst: vmSnatCidr, }) if err != nil { return fmt.Errorf("failed to create route to VM primary IP %s via gateway interface: %w", vmPrimaryIP, err) } err = r.addOrReplaceRoute(ctx, &netlink.Route{ LinkIndex: hostLink.Attrs().Index, Scope: netlink.SCOPE_UNIVERSE, Dst: nil, Gw: net.ParseIP(vmPrimaryIP), }) if err != nil { return fmt.Errorf("failed to create default route via %s: %w", vmPrimaryIP, err) } return nil }) } func (r *StaticGatewayConfigurationReconciler) reconcileVethPairInHost( ctx context.Context, gwns ns.NetNS, snatIP string, ) error { log := log.FromContext(ctx) succeed := false la := netlink.NewLinkAttrs() la.Name = consts.HostVethLinkName mainLink, err := r.Netlink.LinkByName(la.Name) if _, ok := err.(netlink.LinkNotFoundError); ok { log.Info("Creating veth pair in host namespace") veth := &netlink.Veth{ LinkAttrs: la, PeerName: consts.HostLinkName, } err := r.Netlink.LinkAdd(veth) if err != nil { return fmt.Errorf("failed to add veth pair: %w", err) } defer func() { if !succeed { _ = r.Netlink.LinkDel(veth) } }() mainLink, err = r.Netlink.LinkByName(la.Name) if err != nil { return fmt.Errorf("failed to get veth link in host namespace after creation: %w", err) } } else if err != nil { return fmt.Errorf("failed to get veth link in host namespace: %w", err) } err = r.Netlink.LinkSetUp(mainLink) if err != nil { return fmt.Errorf("failed to set veth link in host namespace up: %w", err) } _, snatIPNet, err := net.ParseCIDR(snatIP + "/32") if err != nil { return fmt.Errorf("failed to parse SNAT IP %s: %w", snatIP+"/32", err) } route := &netlink.Route{ LinkIndex: mainLink.Attrs().Index, Scope: netlink.SCOPE_UNIVERSE, Dst: snatIPNet, } if err = r.addOrReplaceRoute(ctx, route); err != nil { return fmt.Errorf("failed to create route to SNAT IP %s via gateway interface: %w", snatIP, err) } defer func() { if !succeed { _ = r.Netlink.RouteDel(route) } }() hostLink, err := r.Netlink.LinkByName(consts.HostLinkName) if err == nil { if err := r.Netlink.LinkSetNsFd(hostLink, int(gwns.Fd())); err != nil { return fmt.Errorf("failed to move veth peer link to gateway namespace: %w", err) } } else if _, ok := err.(netlink.LinkNotFoundError); !ok { return fmt.Errorf("failed to get veth peer link in host namespace: %w", err) } succeed = true return nil } func (r *StaticGatewayConfigurationReconciler) addOrReplaceRoute(ctx context.Context, route *netlink.Route) error { log := log.FromContext(ctx) equalRoute := func(r1, r2 *netlink.Route) bool { size1, _ := to.Val(r1.Dst).Mask.Size() size2, _ := to.Val(r2.Dst).Mask.Size() return r1.LinkIndex == r2.LinkIndex && r1.Scope == r2.Scope && to.Val(r1.Dst).IP.Equal(to.Val(r2.Dst).IP) && size1 == size2 && r1.Gw.Equal(r2.Gw) } routes, err := r.Netlink.RouteList(nil, nl.FAMILY_ALL) if err != nil { return err } foundRoute := false for i := range routes { if equalRoute(&routes[i], route) { foundRoute = true } } if !foundRoute { log.Info("Adding new route", "route", *route) if err = r.Netlink.RouteReplace(route); err != nil { return err } } return nil } func (r *StaticGatewayConfigurationReconciler) ensureIPTablesChain( ctx context.Context, table utiliptables.Table, targetChain utiliptables.Chain, sourceChain utiliptables.Chain, jumpRuleComment string, chainRules [][]string, ) error { log := log.FromContext(ctx) // ensure target chain exists log.Info("Ensuring iptables chain", "table", table, "target chain", targetChain) if _, err := r.IPTables.EnsureChain(table, targetChain); err != nil { return fmt.Errorf("failed to ensure chain %s in table %s: %w", targetChain, table, err) } // ensure jump rule exists, we use EnsureRule because we do not want to flush all rules in the source chain log.Info("Ensuring jump rule", "source chain", sourceChain) if _, err := r.IPTables.EnsureRule(utiliptables.Prepend, table, sourceChain, "-m", "comment", "--comment", jumpRuleComment, "-j", string(targetChain)); err != nil { return fmt.Errorf("failed to ensure jump rule from chain %s to chain %s in table %s: %w", sourceChain, targetChain, table, err) } if len(chainRules) == 0 { return nil } // ensure all rules in the target chain atomically lines := bytes.NewBuffer(nil) writeLine(lines, "*"+string(table)) writeLine(lines, utiliptables.MakeChainLine(targetChain)) for _, rule := range chainRules { writeRule(lines, string(utiliptables.Append), targetChain, rule...) } writeLine(lines, "COMMIT") log.Info("Restoring rules", "rules", lines.String()) if err := r.IPTables.RestoreAll(lines.Bytes(), utiliptables.NoFlushTables, utiliptables.NoRestoreCounters); err != nil { return fmt.Errorf("failed to restore rules in chain %s in table %s: %w", targetChain, table, err) } return nil } func (r *StaticGatewayConfigurationReconciler) removeIPTablesChains( ctx context.Context, table utiliptables.Table, targetChains []utiliptables.Chain, sourceChains []utiliptables.Chain, jumpRuleComments []string, ) error { log := log.FromContext(ctx) iptablesData := bytes.NewBuffer(nil) if err := r.IPTables.SaveInto(table, iptablesData); err != nil { return fmt.Errorf("failed to save iptables data for table %s: %w", table, err) } existingChains := utiliptables.GetChainsFromTable(iptablesData.Bytes()) for i, targetChain := range targetChains { sourceChain := sourceChains[i] jumpRuleComment := jumpRuleComments[i] if _, ok := existingChains[targetChain]; ok { // delete jump rule first log.Info("Deleting jump rule", "source chain", sourceChain, "target chain", targetChain) if err := r.IPTables.DeleteRule(table, sourceChain, "-m", "comment", "--comment", jumpRuleComment, "-j", string(targetChain)); err != nil { return fmt.Errorf("failed to delete jump rule from chain %s to chain %s in table %s: %w", sourceChain, targetChain, table, err) } log.Info("Flushing and deleting chain", "table", table, "target chain", targetChain) lines := bytes.NewBuffer(nil) writeLine(lines, "*"+string(table)) writeLine(lines, utiliptables.MakeChainLine(targetChain)) writeLine(lines, "-X", string(targetChain)) writeLine(lines, "COMMIT") if err := r.IPTables.Restore(table, lines.Bytes(), utiliptables.NoFlushTables, utiliptables.NoRestoreCounters); err != nil { return fmt.Errorf("failed to restore iptables table %s: %w", table, err) } } } return nil } func (r *StaticGatewayConfigurationReconciler) updateGatewayNodeStatus( ctx context.Context, gwConfig egressgatewayv1alpha1.GatewayConfiguration, add bool, ) error { log := log.FromContext(ctx) gwStatusKey := types.NamespacedName{ Namespace: os.Getenv(consts.PodNamespaceEnvKey), Name: os.Getenv(consts.NodeNameEnvKey), } gwStatus := &egressgatewayv1alpha1.GatewayStatus{} if err := r.Get(ctx, gwStatusKey, gwStatus); err != nil { if !apierrors.IsNotFound(err) { log.Error(err, "failed to get existing gateway status object %s/%s", gwStatusKey.Namespace, gwStatusKey.Name) return err } else { if !add { // ignore creating object during cleanup return nil } // gwStatus does not exist, create a new one log.Info(fmt.Sprintf("Creating new gateway status(%s/%s)", gwStatusKey.Namespace, gwStatusKey.Name)) node := &corev1.Node{} if err := r.Get(ctx, types.NamespacedName{Name: os.Getenv(consts.NodeNameEnvKey)}, node); err != nil { return fmt.Errorf("failed to get current node: %w", err) } gwStatus := &egressgatewayv1alpha1.GatewayStatus{ ObjectMeta: metav1.ObjectMeta{ Name: gwStatusKey.Name, Namespace: gwStatusKey.Namespace, }, Spec: egressgatewayv1alpha1.GatewayStatusSpec{ ReadyGatewayConfigurations: []egressgatewayv1alpha1.GatewayConfiguration{gwConfig}, }, } if err := controllerutil.SetOwnerReference(node, gwStatus, r.Client.Scheme()); err != nil { return fmt.Errorf("failed to set gwStatus owner reference to node: %w", err) } log.Info("Creating new gateway status object") if err := r.Create(ctx, gwStatus); err != nil { return fmt.Errorf("failed to create gwStatus object: %w", err) } } } else { changed := false found := false for i, gwConf := range gwStatus.Spec.ReadyGatewayConfigurations { if gwConf.InterfaceName == gwConfig.InterfaceName { if !add { changed = true gwStatus.Spec.ReadyGatewayConfigurations = append(gwStatus.Spec.ReadyGatewayConfigurations[:i], gwStatus.Spec.ReadyGatewayConfigurations[i+1:]...) } found = true break } } if add && !found { gwStatus.Spec.ReadyGatewayConfigurations = append(gwStatus.Spec.ReadyGatewayConfigurations, gwConfig) changed = true } if !add { for i := len(gwStatus.Spec.ReadyPeerConfigurations) - 1; i >= 0; i = i - 1 { if gwStatus.Spec.ReadyPeerConfigurations[i].InterfaceName == gwConfig.InterfaceName { changed = true gwStatus.Spec.ReadyPeerConfigurations = append(gwStatus.Spec.ReadyPeerConfigurations[:i], gwStatus.Spec.ReadyPeerConfigurations[i+1:]...) } } } if changed { log.Info("Updating gateway status object") if err := r.Update(ctx, gwStatus); err != nil { return fmt.Errorf("failed to update gwStatus object: %w", err) } } } return nil } func getWireguardInterfaceName(gwConfig *egressgatewayv1alpha1.StaticGatewayConfiguration) string { return consts.WiregaurdLinkNamePrefix + fmt.Sprintf("%d", gwConfig.Status.Port) } func getPacketMark(linkName string) (int, error) { mark, err := strconv.Atoi(strings.TrimPrefix(linkName, consts.WiregaurdLinkNamePrefix)) if err != nil { return -1, fmt.Errorf("failed to parse mark from link name (%s): %w", linkName, err) } return mark, err } // Similar syntax to utiliptables.Interface.EnsureRule, except you don't pass a table // (you must write these rules under the line with the table name) func writeRule(lines *bytes.Buffer, position string, chain utiliptables.Chain, args ...string) { fullArgs := append([]string{position, string(chain)}, args...) writeLine(lines, fullArgs...) } // Join all words with spaces, terminate with newline and write to buf. func writeLine(lines *bytes.Buffer, words ...string) { lines.WriteString(strings.Join(words, " ") + "\n") }