cmd/dns/dns.go (80 lines of code) (raw):
/*
Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
package dns
import (
"fmt"
"strings"
"time"
"github.com/facebookincubator/fbender/cmd/core/errors"
"github.com/facebookincubator/fbender/cmd/core/input"
"github.com/facebookincubator/fbender/cmd/core/options"
"github.com/facebookincubator/fbender/cmd/core/runner"
tester "github.com/facebookincubator/fbender/tester/dns"
"github.com/facebookincubator/fbender/utils"
"github.com/miekg/dns"
"github.com/spf13/cobra"
)
// DefaultServerPort is the standard DNS port (53). params passes it to
// utils.WithDefaultPort as the fallback port for the configured target.
const (
	DefaultServerPort = 53
)
// params assembles the runner parameters for a DNS test.
//
// It reads the boolean "randomize" flag and the "protocol" flag (via
// GetProtocol) from cmd. It then builds a request generator over o.Input that
// parses each line with inputTransformer and applies the modifiers chosen by
// getModifiers. Finally it configures a DNS tester for o.Target, with
// DefaultServerPort applied through utils.WithDefaultPort, using o.Timeout
// and the selected protocol. Any error from these steps is returned as-is.
func params(cmd *cobra.Command, o *options.Options) (*runner.Params, error) {
	flags := cmd.Flags()

	randomize, err := flags.GetBool("randomize")
	if err != nil {
		//nolint:wrapcheck
		return nil, err
	}

	protocol, err := GetProtocol(flags, "protocol")
	if err != nil {
		return nil, err
	}

	modifiers := getModifiers(randomize)
	generator, err := input.NewRequestGenerator(o.Input, inputTransformer, modifiers...)
	if err != nil {
		//nolint:wrapcheck
		return nil, err
	}

	dnsTester := &tester.Tester{
		Target:   utils.WithDefaultPort(o.Target, DefaultServerPort),
		Timeout:  o.Timeout,
		Protocol: protocol,
	}

	return &runner.Params{
		Tester:           dnsTester,
		RequestGenerator: generator,
	}, nil
}
// inputTransformer parses one input line of the form "Domain QType [RCode]"
// into a *tester.ExtendedMsg carrying a single question for Domain (made
// fully qualified with dns.Fqdn) of type QType.
//
// QType and RCode are matched case-insensitively against dns.StringToType
// and dns.StringToRcode, so "a", "AAAA", "nxdomain" and "NOERROR" are all
// accepted. When RCode is omitted the message's Rcode is set to -1; this
// presumably means "do not check the response code" in the tester (TODO
// confirm against tester/dns). Fields after the third are ignored.
//
// It returns an error wrapping errors.ErrInvalidFormat when fewer than two
// fields are present, or when QType or RCode is not recognized.
func inputTransformer(input string) (interface{}, error) {
	var domain, typeString, rcodeString string

	// Sscanf reports an error whenever it fills fewer than all three verbs.
	// Two fields is the valid "no RCode" form, so only n < 2 is fatal.
	n, err := fmt.Sscanf(input, "%s %s %s", &domain, &typeString, &rcodeString)
	if err != nil && n < 2 {
		return nil, fmt.Errorf("%w, want: \"Domain QType [RCode]\", got: %q", errors.ErrInvalidFormat, input)
	}

	msgTyp, ok := dns.StringToType[strings.ToUpper(typeString)]
	if !ok {
		return nil, fmt.Errorf("%w, invalid QType: %q", errors.ErrInvalidFormat, typeString)
	}

	msg := new(tester.ExtendedMsg)
	msg.SetQuestion(dns.Fqdn(domain), msgTyp)
	msg.Rcode = -1 // no expected RCode supplied

	if n == 3 {
		// StringToRcode keys are upper-case ("NOERROR", "NXDOMAIN", ...).
		// Normalize like the QType lookup above so both fields behave the same.
		rcode, ok := dns.StringToRcode[strings.ToUpper(rcodeString)]
		if !ok {
			return nil, fmt.Errorf("%w, invalid RCode: %q", errors.ErrInvalidFormat, rcodeString)
		}
		msg.Rcode = rcode
	}

	return msg, nil
}
// getModifiers returns the request modifiers to apply to each generated
// request. The result is randomPrefixModifier when randomize is true, and a
// non-nil empty slice otherwise.
func getModifiers(randomize bool) []input.Modifier {
	selected := []input.Modifier{}
	if randomize {
		selected = append(selected, randomPrefixModifier)
	}
	return selected
}
// prefixLength is the length argument passed to utils.RandomHex when
// building the random label used by randomPrefixModifier.
const prefixLength = 16

// randomPrefixModifier returns a new message whose question name is the
// original question name prefixed with "<unix-seconds>.<random-hex>.". The
// question type and Rcode are copied from the original message.
//
// A fresh message is built instead of editing request in place. Editing in
// place would make repeated applications stack prefixes onto the shared
// original.
//
// It returns an error wrapping errors.ErrInvalidType when request is not a
// non-nil *tester.ExtendedMsg. It returns an error wrapping
// errors.ErrInvalidFormat when the message has no question. If
// utils.RandomHex fails, its error is returned unchanged.
func randomPrefixModifier(request interface{}) (interface{}, error) {
	msg, ok := request.(*tester.ExtendedMsg)
	if !ok {
		return nil, fmt.Errorf("%w, want: *dns.ExtendedMsg, got: %T", errors.ErrInvalidType, request)
	}
	// A typed nil passes the assertion above but would panic below.
	if msg == nil {
		return nil, fmt.Errorf("%w, want: non-nil *dns.ExtendedMsg, got: nil", errors.ErrInvalidType)
	}
	// Question[0] is read below; reject an empty question section instead of panicking.
	if len(msg.Question) == 0 {
		return nil, fmt.Errorf("%w, want: message with a question, got: none", errors.ErrInvalidFormat)
	}

	randomLabel, err := utils.RandomHex(prefixLength)
	if err != nil {
		//nolint:wrapcheck
		return nil, err
	}

	question := msg.Question[0]
	domain := fmt.Sprintf("%d.%s.%s", time.Now().Unix(), randomLabel, question.Name)

	modified := new(tester.ExtendedMsg)
	modified.SetQuestion(dns.Fqdn(domain), question.Qtype)
	modified.Rcode = msg.Rcode

	return modified, nil
}