tester/dhcpv4/tester.go (67 lines of code) (raw):
/*
Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
package dhcpv4
import (
"errors"
"fmt"
"net"
"time"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/insomniacslk/dhcp/dhcpv4/async"
"github.com/pinterest/bender"
protocol "github.com/pinterest/bender/dhcpv4"
)
// Tester is a load tester for DHCPv4.
//
// Before builds an async client aimed at Target; BeforeEach opens that client
// and AfterEach closes it around every test.
type Tester struct {
	// Target is the DHCPv4 server under test, in a form accepted by
	// net.ResolveUDPAddr for "udp4" (for example "10.0.0.1:67").
	Target string

	// Timeout is used as both the read and the write timeout of the client.
	Timeout time.Duration

	// BufferSize is passed to the client's Open before every test.
	BufferSize int

	// Interface names the local network interface whose first usable IPv4
	// address the client binds to. When empty, "eth0" is used.
	Interface string

	client *async.Client
}

// Before is called before the first test.
//
// It resolves Target, picks the first usable IPv4 address of the local
// interface named by Interface ("eth0" when Interface is empty), and builds
// the async client that BeforeEach opens and AfterEach closes. The client is
// not opened here.
func (t *Tester) Before(options interface{}) error {
	target, err := net.ResolveUDPAddr("udp4", t.Target)
	if err != nil {
		return fmt.Errorf("unable to set up the tester: %w", err)
	}

	ifname := t.Interface
	if ifname == "" {
		// Default kept for configurations that do not set Interface.
		ifname = "eth0"
	}

	addr, err := getLocalIPv4(ifname)
	if err != nil {
		return fmt.Errorf("unable to set up the tester: %w", err)
	}

	// NOTE(review): the local socket is bound to async.DefaultServerPort rather
	// than a client port, presumably so the tester behaves like a relay agent;
	// confirm against the requests built by the caller.
	t.client = &async.Client{
		ReadTimeout:  t.Timeout,
		WriteTimeout: t.Timeout,
		RemoteAddr:   target,
		LocalAddr:    &net.UDPAddr{IP: addr, Port: async.DefaultServerPort},
		IgnoreErrors: true,
	}
	return nil
}
// After is called after all tests are finished.
//
// It is deliberately a no-op: nothing is torn down here, and per-test cleanup
// of the client happens in AfterEach.
func (t *Tester) After(interface{}) {
	// Intentionally empty.
}
// BeforeEach is called before every test.
//
// It opens the client built by Before, passing BufferSize to Open. Before must
// have run first, since t.client is only assigned there. An Open failure is
// wrapped with context, matching how Before reports its errors.
func (t *Tester) BeforeEach(options interface{}) error {
	if err := t.client.Open(t.BufferSize); err != nil {
		return fmt.Errorf("unable to open the DHCPv4 client: %w", err)
	}
	return nil
}
// AfterEach is called after every test.
//
// It closes the client that BeforeEach opened.
func (t *Tester) AfterEach(interface{}) {
	c := t.client
	c.Close()
}
// validator is the response check handed to protocol.CreateExecutor.
//
// It accepts every request/response pair unconditionally.
// NOTE(review): results therefore do not depend on reply contents; confirm
// this is intended.
func validator(_, _ *dhcpv4.DHCPv4) error {
	// Intentionally permissive: nothing is inspected.
	return nil
}
// RequestExecutor returns a request executor.
//
// The executor is built by protocol.CreateExecutor around the client created
// in Before, with validator passed along as the response check.
func (t *Tester) RequestExecutor(interface{}) (bender.RequestExecutor, error) {
	c := t.client
	return protocol.CreateExecutor(c, validator)
}
// ErrNoAddress is returned (wrapped) by getLocalIPv4 when an interface has no
// IPv4 address assigned that is neither loopback nor link-local.
var ErrNoAddress = errors.New("no ipv4 address found")

// getLocalIPv4 returns the first IPv4 address of interface ifname that is
// neither loopback nor link-local, in its 4-byte form.
//
// Lookup failures are wrapped with the interface name, because the error from
// net.InterfaceByName does not include it. When the interface exists but has
// no usable address, the returned error wraps ErrNoAddress.
func getLocalIPv4(ifname string) (net.IP, error) {
	iface, err := net.InterfaceByName(ifname)
	if err != nil {
		return nil, fmt.Errorf("looking up interface %q: %w", ifname, err)
	}

	ifaddrs, err := iface.Addrs()
	if err != nil {
		return nil, fmt.Errorf("listing addresses of interface %q: %w", ifname, err)
	}

	for _, ifaddr := range ifaddrs {
		ipnet, ok := ifaddr.(*net.IPNet)
		if !ok {
			continue
		}
		ip4 := ipnet.IP.To4()
		if ip4 == nil || ip4.IsLoopback() || ip4.IsLinkLocalUnicast() {
			continue
		}
		return ip4, nil
	}

	return nil, fmt.Errorf("%w, interface: %s", ErrNoAddress, ifname)
}