pkg/utils/utils.go (370 lines of code) (raw):

package utils import ( "crypto/sha1" "encoding/binary" "encoding/hex" "fmt" "net" "strings" "unsafe" "github.com/aws/aws-network-policy-agent/api/v1alpha1" "github.com/go-logr/logr" multierror "github.com/hashicorp/go-multierror" "github.com/vishvananda/netlink" corev1 "k8s.io/api/core/v1" ) var ( TCP_PROTOCOL_NUMBER = 6 UDP_PROTOCOL_NUMBER = 17 SCTP_PROTOCOL_NUMBER = 132 ICMP_PROTOCOL_NUMBER = 1 RESERVED_IP_PROTOCOL_NUMBER = 255 // 255 is a reserved protocol value in the IP header ANY_IP_PROTOCOL = 254 TRIE_KEY_LENGTH = 8 TRIE_V6_KEY_LENGTH = 20 TRIE_VALUE_LENGTH = 288 BPF_PROGRAMS_PIN_PATH_DIRECTORY = "/sys/fs/bpf/globals/aws/programs/" BPF_MAPS_PIN_PATH_DIRECTORY = "/sys/fs/bpf/globals/aws/maps/" TC_INGRESS_PROG = "handle_ingress" TC_EGRESS_PROG = "handle_egress" TC_INGRESS_MAP = "ingress_map" TC_EGRESS_MAP = "egress_map" TC_INGRESS_POD_STATE_MAP = "ingress_pod_state_map" TC_EGRESS_POD_STATE_MAP = "egress_pod_state_map" CATCH_ALL_PROTOCOL corev1.Protocol = "ANY_IP_PROTOCOL" DEFAULT_CLUSTER_NAME = "k8s-cluster" ErrFileExists = "file exists" ErrInvalidFilterList = "failed to get filter list" ErrMissingFilter = "no active filter to detach" ) // NetworkPolicyEnforcingMode is the mode of network policy enforcement type NetworkPolicyEnforcingMode string const ( // Strict : strict network policy enforcement Strict NetworkPolicyEnforcingMode = "strict" // Standard :standard network policy enforcement Standard NetworkPolicyEnforcingMode = "standard" ) // IsValidNetworkPolicyEnforcingMode checks if the input string matches any of the enum values func IsValidNetworkPolicyEnforcingMode(input string) bool { switch strings.ToLower(input) { case string(Strict), string(Standard): return true default: return false } } // IsStrictMode checks if NP enforcing mode is strict func IsStrictMode(input string) bool { return strings.ToLower(input) == string(Strict) } // IsStandardMode checks if NP enforcing mode is standard func IsStandardMode(input string) bool { return strings.ToLower(input) == string(Standard) } func GetProtocol(protocolNum int) string { protocolStr := "UNKNOWN" if protocolNum == TCP_PROTOCOL_NUMBER { protocolStr = "TCP" } else if protocolNum == UDP_PROTOCOL_NUMBER { protocolStr = "UDP" } else if protocolNum == SCTP_PROTOCOL_NUMBER { protocolStr = "SCTP" } else if protocolNum == ICMP_PROTOCOL_NUMBER { protocolStr = "ICMP" } else if protocolNum == RESERVED_IP_PROTOCOL_NUMBER { protocolStr = "RESERVED" } else if protocolNum == ANY_IP_PROTOCOL { protocolStr = "ANY PROTOCOL" } return protocolStr } var getLinkByNameFunc = netlink.LinkByName type VerdictType int const ( DENY VerdictType = iota ACCEPT EXPIRED_DELETED ) func (verdictType VerdictType) Index() int { return int(verdictType) } func GetPodNamespacedName(podName, podNamespace string) string { return podName + podNamespace } func GetPodIdentifier(podName, podNamespace string, log logr.Logger) string { if strings.Contains(podName, ".") { log.Info("Replacing '.' character with '_' for pod pin path.") podName = strings.Replace(podName, ".", "_", -1) } podIdentifierPrefix := podName if strings.Contains(string(podName), "-") { tmpName := strings.Split(podName, "-") podIdentifierPrefix = strings.Join(tmpName[:len(tmpName)-1], "-") } return podIdentifierPrefix + "-" + podNamespace } func GetPodIdentifierFromBPFPinPath(pinPath string) (string, string) { pinPathName := strings.Split(pinPath, "/") podIdentifier := strings.Split(pinPathName[7], "_") return podIdentifier[0], podIdentifier[2] } func GetBPFPinPathFromPodIdentifier(podIdentifier string, direction string) string { progName := TC_INGRESS_PROG if direction == "egress" { progName = TC_EGRESS_PROG } pinPath := BPF_PROGRAMS_PIN_PATH_DIRECTORY + podIdentifier + "_" + progName return pinPath } func GetBPFMapPinPathFromPodIdentifier(podIdentifier string, direction string) string { mapName := TC_INGRESS_MAP if direction == "egress" { mapName = TC_EGRESS_MAP } pinPath := BPF_MAPS_PIN_PATH_DIRECTORY + podIdentifier + "_" + mapName return pinPath } func GetPodStateBPFMapPinPathFromPodIdentifier(podIdentifier string, direction string) string { mapName := TC_INGRESS_POD_STATE_MAP if direction == "egress" { mapName = TC_EGRESS_POD_STATE_MAP } pinPath := BPF_MAPS_PIN_PATH_DIRECTORY + podIdentifier + "_" + mapName return pinPath } func GetPolicyEndpointIdentifier(policyName, policyNamespace string) string { return policyName + policyNamespace } func GetParentNPNameFromPEName(policyEndpointName string) string { return policyEndpointName[0:strings.LastIndex(policyEndpointName, "-")] } func getHostLinkByName(name string) (netlink.Link, error) { return getLinkByNameFunc(name) } func GetHostVethName(podName, podNamespace string, interfacePrefixes []string, logger logr.Logger) string { var interfaceName string var errors error h := sha1.New() h.Write([]byte(fmt.Sprintf("%s.%s", podNamespace, podName))) for _, prefix := range interfacePrefixes { interfaceName = fmt.Sprintf("%s%s", prefix, hex.EncodeToString(h.Sum(nil))[:11]) if _, err := getHostLinkByName(interfaceName); err == nil { logger.Info("host veth interface found", "interface name", interfaceName) return interfaceName } else { errors = multierror.Append(errors, fmt.Errorf("failed to find link %s: %w", interfaceName, err)) } } logger.Error(errors, "Not found any interface starting with prefixes and the hash", "prefixes searched", interfacePrefixes, "hash", hex.EncodeToString(h.Sum(nil))[:11]) return "" } func ComputeTrieKey(n net.IPNet, isIPv6Enabled bool) []byte { prefixLen, _ := n.Mask.Size() var key []byte if isIPv6Enabled { // Key format: Prefix length (4 bytes) followed by 16 byte IP key = make([]byte, TRIE_V6_KEY_LENGTH) } else { // Key format: Prefix length (4 bytes) followed by 4 byte IP key = make([]byte, TRIE_KEY_LENGTH) } binary.LittleEndian.PutUint32(key[0:4], uint32(prefixLen)) copy(key[4:], n.IP) return key } func ComputeTrieValue(l4Info []v1alpha1.Port, log logr.Logger, allowAll, denyAll bool) []byte { var startPort, endPort, protocol int value := make([]byte, TRIE_VALUE_LENGTH) startOffset := 0 if len(l4Info) == 0 { allowAll = true } if allowAll || denyAll { protocol = deriveProtocolValue(v1alpha1.Port{}, allowAll, denyAll) binary.LittleEndian.PutUint32(value[startOffset:startOffset+4], uint32(protocol)) startOffset += 4 binary.LittleEndian.PutUint32(value[startOffset:startOffset+4], uint32(startPort)) startOffset += 4 binary.LittleEndian.PutUint32(value[startOffset:startOffset+4], uint32(endPort)) startOffset += 4 log.Info("L4 values: ", "protocol: ", protocol, "startPort: ", startPort, "endPort: ", endPort) } for _, l4Entry := range l4Info { if startOffset >= TRIE_VALUE_LENGTH { log.Error(nil, "No.of unique port/protocol combinations supported for a single endpoint exceeded the supported maximum of 24") return value } endPort = 0 startPort = 0 protocol = deriveProtocolValue(l4Entry, allowAll, denyAll) if l4Entry.Port != nil { startPort = int(*l4Entry.Port) } if l4Entry.EndPort != nil { endPort = int(*l4Entry.EndPort) } log.Info("L4 values: ", "protocol: ", protocol, "startPort: ", startPort, "endPort: ", endPort) binary.LittleEndian.PutUint32(value[startOffset:startOffset+4], uint32(protocol)) startOffset += 4 binary.LittleEndian.PutUint32(value[startOffset:startOffset+4], uint32(startPort)) startOffset += 4 binary.LittleEndian.PutUint32(value[startOffset:startOffset+4], uint32(endPort)) startOffset += 4 } return value } func deriveProtocolValue(l4Info v1alpha1.Port, allowAll, denyAll bool) int { protocol := TCP_PROTOCOL_NUMBER //ProtocolTCP if denyAll { return RESERVED_IP_PROTOCOL_NUMBER } if allowAll { return ANY_IP_PROTOCOL } if l4Info.Protocol == nil { return protocol //Protocol defaults TCP if not specified } if *l4Info.Protocol == corev1.ProtocolUDP { protocol = UDP_PROTOCOL_NUMBER } else if *l4Info.Protocol == corev1.ProtocolSCTP { protocol = SCTP_PROTOCOL_NUMBER } else if *l4Info.Protocol == CATCH_ALL_PROTOCOL { protocol = ANY_IP_PROTOCOL } return protocol } func IsFileExistsError(error string) bool { if error == ErrFileExists { return true } return false } func IsInvalidFilterListError(error string) bool { errCode := strings.Split(error, ":") if errCode[0] == ErrInvalidFilterList { return true } return false } func IsMissingFilterError(error string) bool { errCode := strings.Split(error, "-") if errCode[0] == ErrMissingFilter { return true } return false } func IsCatchAllIPEntry(ipAddr string) bool { ipSplit := strings.Split(ipAddr, "/") if ipSplit[1] == "0" { //if ipSplit[0] == "0.0.0.0" && ipSplit[1] == "0" { return true } return false } func IsNodeIP(nodeIP string, ipCidr string) bool { ipAddr, _, _ := net.ParseCIDR(ipCidr) if net.ParseIP(nodeIP).Equal(ipAddr) { return true } return false } func IsNonHostCIDR(ipAddr string) bool { ipSplit := strings.Split(ipAddr, "/") //Ignore Catch All IP entry as well if ipSplit[1] != "32" && ipSplit[1] != "128" && ipSplit[1] != "0" { return true } return false } func ConvByteArrayToIP(ipInInt uint32) string { hexIPString := fmt.Sprintf("%x", ipInInt) if len(hexIPString)%2 != 0 { hexIPString = "0" + hexIPString } byteData, _ := hex.DecodeString(hexIPString) reverseByteData := reverseByteArray(byteData) return strings.Trim(strings.Join(strings.Fields(fmt.Sprint(reverseByteData)), "."), "[]") } func reverseByteArray(input []byte) []byte { if len(input) == 0 { return input } return append(reverseByteArray(input[1:]), input[0]) } func ConvIntToIPv4(ipaddr uint32) net.IP { ip := make(net.IP, 4) binary.LittleEndian.PutUint32(ip, ipaddr) return ip } func ConvIPv4ToInt(ipaddr net.IP) uint32 { return uint32(ipaddr[0])<<24 | uint32(ipaddr[1])<<16 | uint32(ipaddr[2])<<8 | uint32(ipaddr[3]) } func ConvIntToIPv4NetworkOrder(ipaddr uint32) net.IP { ip := make(net.IP, 4) binary.BigEndian.PutUint32(ip, ipaddr) return ip } func ConvByteToIPv6(ipaddr [16]byte) net.IP { ip := net.IP(ipaddr[:]) return ip } func ConvIPv6ToByte(ipaddr net.IP) []byte { ipaddrBytes := ipaddr.To16() return ipaddrBytes } type ConntrackKeyV6 struct { Source_ip [16]byte Source_port uint16 _ uint16 //Padding Dest_ip [16]byte Dest_port uint16 Protocol uint8 _ uint8 //Padding Owner_ip [16]byte //16 } type ConntrackKey struct { Source_ip uint32 Source_port uint16 _ uint16 //Padding Dest_ip uint32 Dest_port uint16 Protocol uint8 _ uint8 //Padding Owner_ip uint32 } type ConntrackVal struct { Value uint8 } func ConvConntrackV6ToByte(key ConntrackKeyV6) []byte { ipSize := unsafe.Sizeof(key) byteArray := (*[unsafe.Sizeof(key)]byte)(unsafe.Pointer(&key)) byteSlice := byteArray[:ipSize] return byteSlice } func ConvByteToConntrackV6(keyByte []byte) ConntrackKeyV6 { var v6key ConntrackKeyV6 byteArray := (*[unsafe.Sizeof(v6key)]byte)(unsafe.Pointer(&v6key)) copy(byteArray[:], keyByte) return v6key } func CopyV6Bytes(dest *[16]byte, src [16]byte) { for i := 0; i < len(src); i++ { dest[i] = src[i] } } type BPFTrieKey struct { PrefixLen uint32 IP uint32 } type BPFTrieKeyV6 struct { PrefixLen uint32 IP [16]byte } type BPFTrieVal struct { Protocol uint32 StartPort uint32 EndPort uint32 } func ConvTrieV6ToByte(key BPFTrieKeyV6) []byte { ipSize := unsafe.Sizeof(key) byteArray := (*[20]byte)(unsafe.Pointer(&key)) byteSlice := byteArray[:ipSize] return byteSlice } func ConvByteToTrieV6(keyByte []byte) BPFTrieKeyV6 { var v6key BPFTrieKeyV6 byteArray := (*[unsafe.Sizeof(v6key)]byte)(unsafe.Pointer(&v6key)) copy(byteArray[:], keyByte) return v6key }