pkg/cni/routes/routes.go (151 lines of code) (raw):
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package routes
import (
"errors"
"fmt"
"net"
"os"
"path/filepath"
"strconv"
"github.com/containernetworking/cni/pkg/types"
current "github.com/containernetworking/cni/pkg/types/100"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netlink/nl"
"golang.org/x/sys/unix"
"github.com/Azure/kube-egress-gateway/pkg/consts"
"github.com/Azure/kube-egress-gateway/pkg/iptableswrapper"
"github.com/Azure/kube-egress-gateway/pkg/netlinkwrapper"
)
// runner bundles the system-level dependencies this package uses to configure
// pod networking, so all netlink and iptables access goes through one value.
type runner struct {
	// netlink performs the link lookups and the route and rule operations.
	netlink netlinkwrapper.Interface
	// iptables creates the iptables handle used to install mangle-table rules.
	iptables iptableswrapper.Interface
}
// routesRunner is the package-wide set of netlink/iptables implementations used
// by SetPodRoutes and its helpers. It is a package variable, presumably so that
// tests can replace it with fakes (TODO confirm against the test files).
var routesRunner runner

// init wires routesRunner to the real netlink and iptables implementations.
func init() {
	routesRunner.netlink = netlinkwrapper.NewNetLink()
	routesRunner.iptables = iptableswrapper.NewIPTables()
}
// SetPodRoutes configures routing inside the pod network namespace so that pod
// traffic can be steered through the egress gateway. Every route it installs is
// recorded in result.Routes.
//
// Parameters:
//   - ifName: name of the wireguard interface that leads to the gateway.
//   - exceptionCidrs: CIDRs that get a dedicated route. With defaultToGateway
//     they leave through eth0 via the original default gateway, bypassing the
//     egress gateway. Without it they are sent into the wireguard interface.
//   - defaultToGateway: when true, every route that RouteList reports for eth0
//     (all address families) is deleted and result.Routes is reset. A /32
//     link-scope route to the original default gateway is then added on eth0,
//     and the IPv4 default route is pointed at the wireguard interface via
//     fe80::1. When false, the existing routes are left in place.
//   - sysctlDir: sysctl root forwarded to addRoutingForIngress.
//   - result: CNI result whose Routes field is updated in place.
//
// The "original default route" is the last IPv4 route on eth0 whose
// destination is unset or 0.0.0.0/0. An error is returned if none exists.
// Ingress routing is configured by addRoutingForIngress after all routes are
// in place, and its error (if any) is returned as-is.
func SetPodRoutes(ifName string, exceptionCidrs []string, defaultToGateway bool, sysctlDir string, result *current.Result) error {
	baseLink, err := routesRunner.netlink.LinkByName("eth0")
	if err != nil {
		return fmt.Errorf("failed to retrieve eth0 interface: %w", err)
	}
	tunnelLink, err := routesRunner.netlink.LinkByName(ifName)
	if err != nil {
		return fmt.Errorf("failed to retrieve wireguard interface: %w", err)
	}

	existing, err := routesRunner.netlink.RouteList(baseLink, nl.FAMILY_ALL)
	if err != nil {
		return fmt.Errorf("failed to list all routes on eth0: %w", err)
	}

	// Remember the last IPv4 default route on eth0; later matches win.
	var (
		origDefault netlink.Route
		haveDefault bool
	)
	for _, r := range existing {
		if r.Family != nl.FAMILY_V4 {
			continue
		}
		if r.Dst == nil || r.Dst.String() == "0.0.0.0/0" {
			origDefault, haveDefault = r, true
		}
	}
	if !haveDefault {
		return errors.New("failed to find default route")
	}

	baseIndex := baseLink.Attrs().Index
	tunnelIndex := tunnelLink.Attrs().Index

	// Template for routes leaving through eth0 via the original gateway.
	baseTemplate := netlink.Route{
		Gw:        origDefault.Gw,
		LinkIndex: baseIndex,
		Protocol:  unix.RTPROT_STATIC,
	}
	// Template for IPv4 routes entering the wireguard interface, using the
	// IPv6 link-local next hop fe80::1.
	tunnelTemplate := netlink.Route{
		Gw: nil,
		Via: &netlink.Via{
			Addr:       net.ParseIP("fe80::1"),
			AddrFamily: nl.FAMILY_V6,
		},
		LinkIndex: tunnelIndex,
		Scope:     netlink.SCOPE_UNIVERSE,
		Family:    nl.FAMILY_V4,
	}

	if defaultToGateway {
		// Drop everything on eth0, keep the old gateway reachable on-link, and
		// make the wireguard interface the IPv4 default.
		for _, r := range existing {
			if err := routesRunner.netlink.RouteDel(&r); err != nil {
				return fmt.Errorf("failed to delete route (%s): %w", r, err)
			}
		}
		result.Routes = nil

		gwHost := net.IPNet{IP: origDefault.Gw, Mask: net.CIDRMask(32, 32)}
		if err := routesRunner.netlink.RouteReplace(&netlink.Route{
			Dst:       &gwHost,
			LinkIndex: baseIndex,
			Scope:     netlink.SCOPE_LINK,
		}); err != nil {
			return fmt.Errorf("failed to add original gateway route: %w", err)
		}
		result.Routes = append(result.Routes, &types.Route{Dst: gwHost})

		_, anyIPv4, _ := net.ParseCIDR("0.0.0.0/0")
		tunnelDefault := tunnelTemplate
		tunnelDefault.Dst = anyIPv4
		result.Routes = append(result.Routes, &types.Route{Dst: *anyIPv4, GW: net.ParseIP("fe80::1")})
		if err := routesRunner.netlink.RouteReplace(&tunnelDefault); err != nil {
			return fmt.Errorf("failed to add default wireguard route (%s): %w", tunnelDefault, err)
		}
	}

	// Exception CIDRs go the "other way" from the default: out of eth0 when
	// the default points at the gateway, into the tunnel otherwise.
	for _, exception := range exceptionCidrs {
		_, cidr, err := net.ParseCIDR(exception)
		if err != nil {
			return fmt.Errorf("failed to parse cidr (%s): %w", exception, err)
		}

		var (
			route   netlink.Route
			nextHop net.IP
		)
		switch {
		case defaultToGateway:
			route, nextHop = baseTemplate, origDefault.Gw
		default:
			route, nextHop = tunnelTemplate, net.ParseIP("fe80::1")
		}
		route.Dst = cidr

		if err := routesRunner.netlink.RouteReplace(&route); err != nil {
			return fmt.Errorf("failed to add route (%s): %w", route, err)
		}
		result.Routes = append(result.Routes, &types.Route{Dst: *cidr, GW: nextHop})
	}

	return addRoutingForIngress(baseLink, origDefault, sysctlDir)
}
// addRoutingForIngress sets up policy routing so that traffic belonging to
// connections that arrived on eth0 is routed with a dedicated routing table.
// That table holds a copy of defaultRoute.
//
// Steps:
//  1. In the mangle table, three rules are installed with AppendUnique:
//     - packets arriving on eth0 are marked with consts.Eth0Mark;
//     - in consts.PreRoutingChain, the packet mark is saved into the
//       connection mark;
//     - in consts.OutputChain, a connection mark of consts.Eth0Mark is
//       restored onto the packet.
//  2. An ip rule is added that sends packets marked consts.Eth0Mark to routing
//     table consts.Eth0Mark. An already existing identical rule is accepted.
//  3. defaultRoute is installed into that table with RouteReplace.
//  4. rp_filter is set to 2 for "all" and for "eth0" under sysctlDir.
//     (The kernel documents 2 as loose mode; presumably this is chosen so the
//     asymmetric paths are not dropped — confirm.)
//
// defaultRoute is received by value, so setting its Table does not affect the
// caller's copy. eth0Link is not referenced by the current implementation.
func addRoutingForIngress(eth0Link netlink.Link, defaultRoute netlink.Route, sysctlDir string) error {
	// add iptables rules to mark traffic from eth0 and carry the mark via conntrack
	ipt, err := routesRunner.iptables.New()
	if err != nil {
		return fmt.Errorf("failed to create iptable: %w", err)
	}
	mark := strconv.Itoa(consts.Eth0Mark)
	if err := ipt.AppendUnique(consts.MangleTable, consts.PreRoutingChain, "-i", "eth0", "-j", "MARK", "--set-mark", mark); err != nil {
		return fmt.Errorf("failed to append iptables set-mark rule: %w", err)
	}
	if err := ipt.AppendUnique(consts.MangleTable, consts.PreRoutingChain, "-j", "CONNMARK", "--save-mark"); err != nil {
		return fmt.Errorf("failed to append iptables save-mark rule: %w", err)
	}
	if err := ipt.AppendUnique(consts.MangleTable, consts.OutputChain, "-m", "connmark", "--mark", mark, "-j", "CONNMARK", "--restore-mark"); err != nil {
		return fmt.Errorf("failed to append iptables restore-mark rule: %w", err)
	}

	// add ip rule: lookup separate routing table if packet is marked
	rule := netlink.NewRule()
	rule.Mark = uint32(consts.Eth0Mark)
	rule.Table = consts.Eth0Mark
	// netlink's RuleAdd creates rules exclusively, so re-adding an identical
	// rule fails with EEXIST. The iptables (AppendUnique) and route
	// (RouteReplace) steps are idempotent; accept an existing rule so this
	// step is too. syscall.Errno(EEXIST) satisfies errors.Is(err, os.ErrExist).
	if err := routesRunner.netlink.RuleAdd(rule); err != nil && !errors.Is(err, os.ErrExist) {
		return fmt.Errorf("failed to add routing rule: %w", err)
	}

	// add the default route into the marked-traffic table
	defaultRoute.Table = consts.Eth0Mark
	if err := routesRunner.netlink.RouteReplace(&defaultRoute); err != nil {
		return fmt.Errorf("failed to add default route via eth0: %w", err)
	}

	// update rp_filter flag
	if err := os.WriteFile(filepath.Join(sysctlDir, "net/ipv4/conf/all/rp_filter"), []byte("2"), 0644); err != nil {
		return fmt.Errorf("failed to write net.ipv4.conf.all.rp_filter: %w", err)
	}
	if err := os.WriteFile(filepath.Join(sysctlDir, "net/ipv4/conf/eth0/rp_filter"), []byte("2"), 0644); err != nil {
		return fmt.Errorf("failed to write net.ipv4.conf.eth0.rp_filter: %w", err)
	}
	return nil
}