tester/dns/tester.go (64 lines of code) (raw):

/* Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. This source code is licensed under the BSD-style license found in the LICENSE file in the root directory of this source tree. */ package dns import ( "errors" "fmt" "time" "github.com/miekg/dns" "github.com/pinterest/bender" protocol "github.com/pinterest/bender/dns" ) // ExtendedMsg wraps a dns.Msg with expectations. type ExtendedMsg struct { dns.Msg Rcode int } // Tester is a load tester for DNS. type Tester struct { Target string Timeout time.Duration Protocol string client *dns.Client } // ErrInvalidRequest is an error raised when the request is invalid. var ErrInvalidRequest = errors.New("invalid request") // ErrInvalidResponse is raised when the response is invalid. var ErrInvalidResponse = errors.New("invalid response") // Before is called before the first test. func (t *Tester) Before(options interface{}) error { //nolint:exhaustivestruct t.client = &dns.Client{ ReadTimeout: t.Timeout, DialTimeout: t.Timeout, WriteTimeout: t.Timeout, Net: t.Protocol, } return nil } // After is called after all tests are finished. func (t *Tester) After(_ interface{}) {} // BeforeEach is called before every test. func (t *Tester) BeforeEach(_ interface{}) error { return nil } // AfterEach is called after every test. func (t *Tester) AfterEach(_ interface{}) {} func validator(request, response *dns.Msg) error { if request.Id != response.Id { return fmt.Errorf("%w: %d, want: %d", ErrInvalidResponse, request.Id, response.Id) } return nil } // RequestExecutor returns a request executor. func (t *Tester) RequestExecutor(options interface{}) (bender.RequestExecutor, error) { innerExecutor := protocol.CreateExecutor(t.client, validator, t.Target) return func(n int64, request interface{}) (interface{}, error) { asExtended, ok := request.(*ExtendedMsg) if !ok { return nil, fmt.Errorf("%w: invalid type, want: *ExtendedMsg, got: %T", ErrInvalidRequest, request) } resp, err := innerExecutor(n, &asExtended.Msg) if err != nil { return resp, err } asMsg, ok := resp.(*dns.Msg) if !ok { return nil, fmt.Errorf("%w: invalid type, want: *dns.Msg, got: %T", ErrInvalidResponse, resp) } if asExtended.Rcode != -1 && asExtended.Rcode != asMsg.Rcode { return resp, fmt.Errorf( "%w: invalid rcode want: %q, got: %q", ErrInvalidResponse, dns.RcodeToString[asExtended.Rcode], dns.RcodeToString[asMsg.Rcode]) } return resp, nil }, nil }