pkg/dev/portforward/pod_forwarder.go (198 lines of code) (raw):

// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License 2.0;
// you may not use this file except in compliance with the Elastic License 2.0.

package portforward

import (
	"context"
	"errors"
	"fmt"
	"net"
	"regexp"
	"strings"
	"sync"

	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/fields"
	"k8s.io/apimachinery/pkg/types"
	"k8s.io/apimachinery/pkg/watch"
	"k8s.io/client-go/kubernetes"
	"sigs.k8s.io/controller-runtime/pkg/client/config"

	"github.com/elastic/cloud-on-k8s/v3/pkg/utils/k8s"
	utilsnet "github.com/elastic/cloud-on-k8s/v3/pkg/utils/net"
)

// PodForwarder enables redirecting tcp connections through "kubectl port-forward" tooling.
//
// A PodForwarder is created with NewPodForwarder, started with Run, and used with DialContext, which redirects
// dials of addr to a local port forwarded to the pod.
type PodForwarder struct {
	// network and addr are the network (passed through to the dialer) and the pod address (host:port) that this
	// forwarder stands in for. The port part of addr is the remote port forwarded on the pod.
	network, addr string

	// podNSN identifies the pod behind addr, as resolved by parsePodAddr.
	podNSN types.NamespacedName

	// clientset is used to stop the pod forwarder if the pod is deleted, may be set to nil to skip checking
	clientset *kubernetes.Clientset

	// initChan is closed once the port-forwarder is ready to redirect connections, or when Run returns without
	// ever becoming ready. DialContext waits on it before redirecting connections.
	initChan chan struct{}

	// viaErr, when non-nil, is returned by DialContext instead of dialing viaAddr. Run sets it once it is no
	// longer forwarding.
	viaErr error

	// viaAddr is the local address (127.0.0.1:<ephemeral port>) that connections are redirected to once the
	// port-forwarder is ready.
	viaAddr string

	// ephemeralPortFinder is used to find an available ephemeral local port to forward from.
	ephemeralPortFinder func() (string, error)

	// portForwarderFactory is used to facilitate testing without using the API
	portForwarderFactory PortForwarderFactory

	// dialerFunc is used to facilitate testing without making new connections
	dialerFunc dialerFunc
}

// Compile-time check that *PodForwarder implements Forwarder.
var _ Forwarder = &PodForwarder{}

// PortForwarderFactory is a factory for port forwarders.
//
// Run calls it with the pod namespace and name, a single "localPort:remotePort" mapping in ports, and a readyChan
// on which Run waits: the returned PortForwarder is expected to close (or send on) readyChan once forwarding is
// established.
type PortForwarderFactory func(
	ctx context.Context,
	namespace, podName string,
	ports []string,
	readyChan chan struct{},
) (PortForwarder, error)

// PortForwarder is a port forwarder that may be started by calling ForwardPorts.
type PortForwarder interface { ForwardPorts() error } // dialerFunc is a factory for connections type dialerFunc func(ctx context.Context, network, address string) (net.Conn, error) // NewPodForwarder returns a new initialized podForwarder func NewPodForwarder(ctx context.Context, network, addr string, clientset *kubernetes.Clientset) (*PodForwarder, error) { podNSN, err := parsePodAddr(ctx, addr, clientset) if err != nil { return nil, err } return &PodForwarder{ network: network, addr: addr, podNSN: *podNSN, clientset: clientset, initChan: make(chan struct{}), ephemeralPortFinder: utilsnet.GetRandomPort, portForwarderFactory: defaultPortForwarderFactory, dialerFunc: defaultDialerFunc, }, nil } // newDefaultKubernetesClientset creates a new Clientset func newDefaultKubernetesClientset() (*kubernetes.Clientset, error) { cfg, err := config.GetConfig() if err != nil { return nil, err } return kubernetes.NewForConfig(cfg) } // podDNSRegex matches pods FQDN such as {name}.{namespace}.pod var podDNSRegex = regexp.MustCompile(`^.+\..+$`) // podIPRegex matches any ipv4 address. var podIPv4Regex = regexp.MustCompile(`^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$`) // parsePodAddr parses the pod name and namespace from an address. 
func parsePodAddr(ctx context.Context, addr string, clientSet *kubernetes.Clientset) (*types.NamespacedName, error) { host, _, err := net.SplitHostPort(addr) if err != nil { return nil, err } if podIPv4Regex.MatchString(host) { // we got an IP address // try to map it to a pod name and namespace return getPodWithIP(ctx, host, clientSet) } if podDNSRegex.MatchString(host) { // retrieve pod name and namespace from addr parts := strings.SplitN(host, ".", 4) if len(parts) <= 1 { return nil, fmt.Errorf("unsupported pod address format: %s", host) } if len(parts) == 2 || parts[2] == syntheticDNSSegment { // podname.ns[.pod] from service forwarder or direct call return &types.NamespacedName{Namespace: parts[1], Name: parts[0]}, nil } // podname.subdomain.ns return &types.NamespacedName{Namespace: parts[2], Name: parts[0]}, nil } return nil, fmt.Errorf("unsupported pod address format: %s", host) } // getPodWithIP requests the apiserver for pods with the given IP assigned. func getPodWithIP(ctx context.Context, ip string, clientSet *kubernetes.Clientset) (*types.NamespacedName, error) { pods, err := clientSet.CoreV1(). Pods(""). 
List(ctx, metav1.ListOptions{ FieldSelector: fmt.Sprintf("status.podIP=%s", ip), }) if err != nil { return nil, err } if pods == nil || len(pods.Items) == 0 { return nil, fmt.Errorf("pod with IP %s not found", ip) } nsn := k8s.ExtractNamespacedName(&(pods.Items[0].ObjectMeta)) return &nsn, nil } // defaultPortForwarderFactory is the default factory used for port forwarders outside of tests var defaultPortForwarderFactory PortForwarderFactory = func( ctx context.Context, namespace, podName string, ports []string, readyChan chan struct{}, ) (PortForwarder, error) { return newKubectlPortForwarder(ctx, namespace, podName, ports, readyChan) } // defaultDialerFunc is the default dialer function we use outside of tests var defaultDialerFunc dialerFunc = func(ctx context.Context, network, address string) (net.Conn, error) { var d net.Dialer return d.DialContext(ctx, network, address) } // DialContext connects to the podForwarder address using the provided context. func (f *PodForwarder) DialContext(ctx context.Context) (net.Conn, error) { // wait until we're initialized or context is done select { case <-f.initChan: case <-ctx.Done(): } // context has an error, so we can give up, most likely exceeded our timeout if ctx.Err() != nil { return nil, ctx.Err() } // we have an error to return if f.viaErr != nil { return nil, f.viaErr } log.V(2).Info("Redirecting dial call", "addr", f.addr, "via", f.viaAddr) return f.dialerFunc(ctx, f.network, f.viaAddr) } // Run starts a port forwarder and blocks until either the port forwarding fails or the context is done. func (f *PodForwarder) Run(ctx context.Context) error { log.V(2).Info("Running port-forwarder for", "addr", f.addr) defer log.V(2).Info("No longer running port-forwarder for", "addr", f.addr) // used as a safeguard to ensure we only close the init channel once initCloser := sync.Once{} // wrap this in a sync.Once because it will panic if it happens more than once // ensure that initChan is closed even if we were never ready. 
defer initCloser.Do(func() { close(f.initChan) }) // derive a new context so we can ensure the port-forwarding is stopped before we return and that we return as // soon as the port-forwarding stops, whichever occurs first runCtx, runCtxCancel := context.WithCancel(ctx) defer runCtxCancel() if f.clientset != nil { log.V(2).Info("Watching pod for changes", "namespace", f.podNSN.Namespace, "pod_name", f.podNSN.Name) w, err := f.clientset.CoreV1().Pods(f.podNSN.Namespace).Watch(ctx, metav1.ListOptions{ FieldSelector: fields.OneTermEqualSelector("metadata.name", f.podNSN.Name).String(), }) if err != nil { return fmt.Errorf("unable to watch pod %s for changes: %w", f.podNSN, err) } defer w.Stop() go func() { for { select { case evt := <-w.ResultChan(): if evt.Type == watch.Deleted || evt.Type == watch.Error || evt.Type == "" { log.V(2).Info( "Pod is deleted or watch failed/closed, closing pod forwarder", "namespace", f.podNSN.Namespace, "pod_name", f.podNSN.Name, ) runCtxCancel() return } case <-runCtx.Done(): return } } }() } _, port, err := net.SplitHostPort(f.addr) if err != nil { return err } // find an available local ephemeral port localPort, err := f.ephemeralPortFinder() if err != nil { return err } readyChan := make(chan struct{}) fwd, err := f.portForwarderFactory( runCtx, f.podNSN.Namespace, f.podNSN.Name, []string{localPort + ":" + port}, readyChan, ) if err != nil { return err } // wait for our context to be done or the port forwarder to become ready go func() { select { case <-runCtx.Done(): case <-readyChan: f.viaAddr = "127.0.0.1:" + localPort log.V(2).Info("Ready to redirect connections", "addr", f.addr, "via", f.viaAddr) // wrap this in a sync.Once because it will panic if it happens more than once, which it may if our // outer function returned just as readyChan was closed. initCloser.Do(func() { close(f.initChan) }) } }() err = fwd.ForwardPorts() f.viaErr = errors.New("not currently forwarding") return err }