pkg/dev/portforward/service_forwarder.go (127 lines of code) (raw):
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License 2.0;
// you may not use this file except in compliance with the Elastic License 2.0.
package portforward
import (
"context"
"errors"
"fmt"
"math/rand"
"net"
"strings"
corev1 "k8s.io/api/core/v1"
discoveryv1 "k8s.io/api/discovery/v1"
"k8s.io/apimachinery/pkg/labels"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/intstr"
netutils "k8s.io/utils/net"
"sigs.k8s.io/controller-runtime/pkg/client"
)
// syntheticDNSSegment is the third dot-separated segment of the pod addresses built by
// ServiceForwarder.DialContext, which have the form {pod name}.{pod namespace}.pod:{port}.
const syntheticDNSSegment = "pod"
// ServiceForwarder forwards a single port of a service by dialing one of the pods behind it.
type ServiceForwarder struct {
	// network is handed over unchanged to the forwarder store when a pod forwarder is requested
	network string
	// addr is the address of the service to forward, including the port
	addr string
	// serviceNSN is the namespace and name of the service, as parsed from addr
	serviceNSN types.NamespacedName

	// client is used to look up the service and its endpoint slices during dialing
	client client.Client

	// store is where pod forwarders are looked up or created during dialing
	store *ForwarderStore

	// podForwarderFactory enables injecting a custom forwarder factory in tests
	podForwarderFactory ForwarderFactory
}
// compile-time check that *ServiceForwarder satisfies Forwarder
var _ Forwarder = (*ServiceForwarder)(nil)
// defaultPodForwarderFactory is the pod forwarder factory used outside of tests: it builds pod forwarders
// on top of the default Kubernetes clientset.
var defaultPodForwarderFactory ForwarderFactory = func(ctx context.Context, network, addr string) (Forwarder, error) {
	cs, err := newDefaultKubernetesClientset()
	if err != nil {
		return nil, err
	}
	return NewPodForwarder(ctx, network, addr, cs)
}
// NewServiceForwarder creates a service forwarder for the service designated by addr.
//
// An error is returned when the service name and namespace cannot be parsed from addr.
func NewServiceForwarder(client client.Client, network, addr string) (*ServiceForwarder, error) {
	nsn, err := parseServiceAddr(addr)
	if err != nil {
		return nil, err
	}

	forwarder := &ServiceForwarder{
		network:             network,
		addr:                addr,
		serviceNSN:          *nsn,
		client:              client,
		store:               NewForwarderStore(),
		podForwarderFactory: defaultPodForwarderFactory,
	}
	return forwarder, nil
}
// parseServiceAddr extracts the name and namespace of a service from a connection address.
//
// The address is expected to start with {name}.{namespace}. followed by anything else (usually svc and a port);
// an address with fewer than two dots is rejected.
func parseServiceAddr(addr string) (*types.NamespacedName, error) {
	// only the first two segments matter, whatever follows the second dot is kept in one piece
	segments := strings.SplitN(addr, ".", 3)
	if len(segments) < 3 {
		return nil, fmt.Errorf("unsupported service address format: %s", addr)
	}

	name, namespace := segments[0], segments[1]
	return &types.NamespacedName{Name: name, Namespace: namespace}, nil
}
// Run blocks until ctx is done, then returns nil. The service forwarder does nothing else while running.
func (f *ServiceForwarder) Run(ctx context.Context) error {
	// TODO: /could/ consider snipping connections here when pods turn unready, but that does not match the default
	// Service behavior
	done := ctx.Done()
	<-done
	return nil
}
// DialContext dials one of the ready pods behind this service forwarder.
//
// The port of the forwarder address is matched against the ports of the service to determine the target port.
// The EndpointSlices of the service are then searched for ready endpoints that expose that port and reference a
// pod. As an approximation to load balancing, a random ready pod will be chosen for each dialing attempt.
//
// An error is returned if the address has no valid port, if the service or its EndpointSlices cannot be
// retrieved, if the service is not listening on the requested port, if no ready pod is found, or if the pod
// forwarder cannot be created or dialed.
func (f *ServiceForwarder) DialContext(ctx context.Context) (net.Conn, error) {
	_, servicePortStr, err := net.SplitHostPort(f.addr)
	if err != nil {
		return nil, err
	}
	servicePort, err := netutils.ParsePort(servicePortStr, false)
	if err != nil {
		return nil, err
	}

	service := corev1.Service{}
	if err := f.client.Get(ctx, f.serviceNSN, &service); err != nil {
		return nil, err
	}

	// TODO: support named ports? how it's supposed to work is not quite clear atm, and we don't use it ourselves
	// so this is deferred to later
	targetPort := intstr.FromInt(0)
	for _, port := range service.Spec.Ports {
		if port.Port != int32(servicePort) {
			continue
		}
		// default to using the same port between the service and the target
		targetPort = intstr.FromInt(int(port.Port))
		// if .TargetPort is non-0, we use that
		if port.TargetPort.IntValue() != 0 {
			targetPort = port.TargetPort
		}
		break
	}
	// a target port of 0 means that no service port matched the requested one
	if targetPort.IntValue() == 0 {
		return nil, fmt.Errorf("service is not listening on port: %d", servicePort)
	}

	endpoints := discoveryv1.EndpointSliceList{}
	listOps := &client.ListOptions{
		LabelSelector: labels.SelectorFromSet(labels.Set{discoveryv1.LabelServiceName: service.Name}),
		Namespace:     service.Namespace,
	}
	if err := f.client.List(ctx, &endpoints, listOps); err != nil {
		return nil, err
	}

	// the numeric target port is the same for every EndpointSlice, compute it once
	wantedPort := int32(targetPort.IntValue())

	var podTargets []*corev1.ObjectReference
	for _, endpointSlice := range endpoints.Items {
		foundPort := false
		for _, port := range endpointSlice.Ports {
			if port.Port != nil && *port.Port == wantedPort {
				foundPort = true
				break
			}
		}
		if !foundPort {
			// Port is not found in the EndpointSlice, try the next one.
			continue
		}

		for _, endpoint := range endpointSlice.Endpoints {
			if endpoint.Conditions.Ready == nil || !*endpoint.Conditions.Ready {
				// Do not forward to a pod that is not ready.
				// Note that if spec.publishNotReadyAddresses is set to "true", then `ready` is always true:
				// https://kubernetes.io/docs/concepts/services-networking/endpoint-slices/#ready
				continue
			}
			// TargetRef is a pointer that may be unset: skip such endpoints instead of dereferencing nil,
			// as there is no pod to forward to anyway.
			if endpoint.TargetRef == nil || endpoint.TargetRef.Kind != "Pod" {
				continue
			}
			podTargets = append(podTargets, endpoint.TargetRef)
		}
	}

	if len(podTargets) == 0 {
		return nil, errors.New("no pod addresses found in service endpoints")
	}
	pod := podTargets[rand.Intn(len(podTargets))] //nolint:gosec

	// this should match a supported format of parsePodAddr(addr string)
	podAddr := fmt.Sprintf("%s.%s.%s:%s", pod.Name, pod.Namespace, syntheticDNSSegment, targetPort.String())
	forwarder, err := f.store.GetOrCreateForwarder(f.network, podAddr, f.podForwarderFactory)
	if err != nil {
		return nil, err
	}

	return forwarder.DialContext(ctx)
}