portforward/portforward.go (117 lines of code) (raw):

// Copyright (c) Microsoft Corporation. // Licensed under the MIT License. package portforward import ( "context" "fmt" "net/http" "net/url" "time" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" "k8s.io/client-go/tools/clientcmd" kubepf "k8s.io/client-go/tools/portforward" "k8s.io/client-go/transport/spdy" "k8s.io/klog/v2" ) // PodPortForwarder is used to forward traffic to specific pod's TCP port from // local listener. type PodPortForwarder struct { // targetPort is the target TCP port. targetPort uint16 // portforwardURL is the pod's portforward URL. portforwardURL *url.URL // restCfg is used to create spdy transport. restCfg *rest.Config portForwarder *kubepf.PortForwarder } // NewPodPortForwarder return a new instance of PodPortForwarder. func NewPodPortForwarder(kubeCfgPath string, namespace, podName string, targetPort uint16) (*PodPortForwarder, error) { restCfg, err := clientcmd.BuildConfigFromFlags("", kubeCfgPath) if err != nil { return nil, err } restCfg.ContentType = "application/vnd.kubernetes.protobuf" restCli, err := kubernetes.NewForConfig(restCfg) if err != nil { return nil, err } if err := ensurePodIsRunning(restCli, namespace, podName); err != nil { return nil, err } u := restCli.CoreV1().RESTClient().Post(). Namespace(namespace). Resource("pods"). Name(podName). SubResource("portforward").URL() return &PodPortForwarder{ targetPort: targetPort, portforwardURL: u, restCfg: restCfg, }, nil } // Start is to start local listener to forward traffic. func (pf *PodPortForwarder) Start() error { transport, upgrader, err := spdy.RoundTripperFor(pf.restCfg) if err != nil { return fmt.Errorf("failed to create spdy transport: %w", err) } dialer := spdy.NewDialer( upgrader, &http.Client{Transport: transport}, "POST", pf.portforwardURL, ) startCh := make(chan struct{}) // pick available local port randomly. kubePortForwarder, err := kubepf.New( dialer, []string{fmt.Sprintf("0:%d", pf.targetPort)}, nil, startCh, &debugLogger{}, &debugLogger{}, ) if err != nil { return fmt.Errorf("failed to init kube port forward: %w", err) } errCh := make(chan error, 1) go func() { errCh <- kubePortForwarder.ForwardPorts() }() select { case <-startCh: case err := <-errCh: return fmt.Errorf("failed to start kube port forward: %w", err) case <-time.After(120 * time.Second): return fmt.Errorf("timeout to start kube port forward") } pf.portForwarder = kubePortForwarder return nil } // GetLocalPort returns the local listener's port. func (pf *PodPortForwarder) GetLocalPort() (uint16, error) { if pf.portForwarder == nil { return 0, fmt.Errorf("kube port forwarder doesn't start") } ports, err := pf.portForwarder.GetPorts() if err != nil { return 0, fmt.Errorf("failed to get local port: %w", err) } return ports[0].Local, nil } // Stop stops port forward. func (pf *PodPortForwarder) Stop() { defer klog.Flush() if pf.portForwarder != nil { pf.portForwarder.Close() } } // ensurePodIsRunning is to check if the target pod is still running. func ensurePodIsRunning(restCli kubernetes.Interface, namespace, podName string) error { pod, err := restCli.CoreV1(). Pods(namespace). Get(context.TODO(), podName, metav1.GetOptions{}) if err != nil { return fmt.Errorf("failed to ensure if %s in %s exists: %w", podName, namespace, err) } if pod.Status.Phase != corev1.PodRunning { return fmt.Errorf("unable to forward port because pod is not running (status=%s)", pod.Status.Phase) } return nil } type debugLogger struct{} func (l *debugLogger) Write(data []byte) (int, error) { klog.V(2).InfoS(string(data)) return len(data), nil }