internal/network/network.go (119 lines of code) (raw):
// Copyright (c) 2016 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package network
import (
"net"
"strings"
"github.com/uber/arachne/internal/log"
"github.com/miekg/dns"
"github.com/pkg/errors"
"go.uber.org/zap"
)
// Family reports the address family of a, as the network string understood
// by the net package: "ip4" or "ip6".
//
// A nil pointer, or an address no longer than net.IPv4len bytes, is reported
// as "ip4", as is any 16-byte IPv4-mapped address (one To4 can convert).
// Every other address is reported as "ip6".
func Family(a *net.IP) string {
	switch {
	// Case expressions are evaluated left to right and stop at the first
	// match, so *a is never dereferenced when a is nil.
	case a == nil, len(*a) <= net.IPv4len, a.To4() != nil:
		return "ip4"
	default:
		return "ip6"
	}
}
// GetSourceAddr discovers the local source address to use in address family
// af ("ip4" or "ip6").
//
// If srcAddr is non-empty it is resolved (it may be an IP literal or a host
// name) and the result is returned. Otherwise, if ifaceName is non-empty, the
// first suitable address on that interface is returned. When neither is set,
// the first suitable address on any non-loopback interface is used.
//
// hostname is retained for API compatibility; it does not affect the result.
func GetSourceAddr(
	af string,
	srcAddr string,
	hostname string,
	ifaceName string,
	logger *log.Logger,
) (*net.IP, error) {
	//TODO => resolve if both interface name and source address are specified and they do not match
	// Source address is specified: resolve the configured address itself.
	// Resolving hostname here instead would silently ignore srcAddr.
	if srcAddr != "" {
		return resolveHost(af, srcAddr, logger)
	}
	// Interface name is specified
	if ifaceName != "" {
		return interfaceAddress(af, ifaceName)
	}
	return anyInterfaceAddress(af)
}
// resolveHost resolves hostname (a DNS name or an IP literal) to a single IP
// address in address family af using net.ResolveIPAddr. On failure it logs a
// warning and returns the resolver's error unchanged.
//TODO replace with net.LookupHost?
func resolveHost(af string, hostname string, logger *log.Logger) (*net.IP, error) {
	ipAddr, resolveErr := net.ResolveIPAddr(af, hostname)
	if resolveErr == nil {
		return &ipAddr.IP, nil
	}
	logger.Warn("failed to DNS resolve hostname with default server",
		zap.String("hostname", hostname),
		zap.Error(resolveErr))
	return nil, resolveErr
}
// ResolveIP returns a DNS name for the given IP address via a reverse lookup.
//
// The system resolver (net.LookupAddr) is tried first, and its first name is
// returned on success. If that lookup fails or yields no names, a warning is
// logged and the lookup is retried against each of servers in turn.
//
// If every attempt fails, ResolveIP returns "" and a non-nil error. It does
// not echo ip back; callers that want the raw IP as a fallback must apply it
// themselves.
func ResolveIP(ip string, servers []net.IP, logger *log.Logger) (string, error) {
	names, err := net.LookupAddr(ip)
	if err == nil && len(names) > 0 {
		return names[0], nil
	}
	// err can be nil here (lookup succeeded with zero names); zap.Error(nil)
	// adds no field in that case.
	logger.Warn("failed to DNS resolve IP with default server",
		zap.String("ip", ip),
		zap.Error(err))
	return resolveIPwServer(ip, servers, logger)
}
// resolveIPwServer performs a reverse (PTR) lookup of ip against each of the
// given DNS servers, in order, on port 53. It returns the target name of the
// first PTR record received.
//
// An error is returned if no servers are configured, if ip cannot be turned
// into a reverse-lookup name, or if no server returns a PTR answer. Servers
// that fail, or whose answer section holds no PTR record, are skipped.
func resolveIPwServer(ip string, servers []net.IP, logger *log.Logger) (string, error) {
	if len(servers) == 0 {
		return "", errors.New("no alternate DNS servers configured")
	}
	fqdn, err := dns.ReverseAddr(ip)
	if err != nil {
		return "", err
	}
	c := dns.Client{}
	m := dns.Msg{}
	m.SetQuestion(fqdn, dns.TypePTR)
	for _, s := range servers {
		// JoinHostPort brackets IPv6 literals ("[2001:db8::1]:53"); plain
		// concatenation would yield an unparseable address for them.
		r, t, err := c.Exchange(&m, net.JoinHostPort(s.String(), "53"))
		if err != nil || r == nil {
			continue
		}
		for _, rr := range r.Answer {
			// The answer section may lead with other record types (e.g. a
			// CNAME for classless in-addr.arpa delegation); only PTR
			// records carry the name we want.
			if rr.Header().Rrtype != dns.TypePTR {
				continue
			}
			logger.Debug("Reverse DNS resolution for ip with user-configured DNS server took",
				zap.String("ip", ip),
				zap.Float64("duration", t.Seconds()))
			// A PTR record renders as "name\tttl\tclass\tPTR\ttarget";
			// the target is the last tab-delimited field. Split always
			// returns at least one element, so this cannot go out of range.
			fields := strings.Split(rr.String(), "\t")
			return fields[len(fields)-1], nil
		}
	}
	logger.Warn("failed to DNS resolve IP with alternate servers", zap.String("ip", ip))
	return "", errors.Errorf("failed to DNS resolve %s with alternate servers", ip)
}
func interfaceAddress(af string, name string) (*net.IP, error) {
iface, err := net.InterfaceByName(name)
if err != nil {
return nil, errors.Wrapf(err, "net.InterfaceByName for %s", name)
}
addrs, err := iface.Addrs()
if err != nil {
return nil, errors.Wrap(err, "iface.Addrs")
}
return findAddrInRange(af, addrs)
}
func anyInterfaceAddress(af string) (*net.IP, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, errors.Wrap(err, "net.Interfaces")
}
for _, iface := range interfaces {
// Skip loopback
if (iface.Flags & net.FlagLoopback) == net.FlagLoopback {
continue
}
addrs, err := iface.Addrs()
// Skip if error getting addresses
if err != nil {
return nil, errors.Wrapf(err, "error getting addresses for interface %s", iface.Name)
}
if len(addrs) > 0 {
return interfaceAddress(af, iface.Name)
}
}
return nil, err
}
// findAddrInRange returns the first entry of addrs that is a *net.IPNet whose
// IP belongs to family af ("ip4" or "ip6") and is not loopback, multicast or
// link-local unicast. The returned pointer refers to that entry's IP field.
// It returns an error if no entry qualifies.
func findAddrInRange(af string, addrs []net.Addr) (*net.IP, error) {
	for _, addr := range addrs {
		cidr, isIPNet := addr.(*net.IPNet)
		if !isIPNet {
			continue
		}
		ip := cidr.IP
		if ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() {
			continue
		}
		isV4 := ip.To4() != nil
		wantV4 := af == "ip4"
		wantV6 := af == "ip6"
		if (wantV4 && isV4) || (wantV6 && !isV4) {
			return &cidr.IP, nil
		}
	}
	return nil, errors.Errorf("could not find a source address in %s address family", af)
}