flags/distribution.go (94 lines of code) (raw):
/*
Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
package flags
import (
"errors"
"fmt"
"sort"
"strings"
"github.com/facebookincubator/fbender/utils"
"github.com/pinterest/bender"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)
// DistributionGenerator is a function that builds a bender.IntervalGenerator
// from a single float64 parameter.
// NOTE(review): the parameter is presumably the target request rate expected
// by bender's interval generators; confirm against the bender documentation.
type DistributionGenerator = func(float64) bender.IntervalGenerator
// Full names under which the built-in interval generators are registered in
// the generators map.
const (
uniformGenerator = "uniform"
exponentialGenerator = "exponential"
)
// generators maps each generator name to the bender function that builds the
// corresponding interval generator. Within this file it is only looked up and
// iterated over, never modified.
//
//nolint:gochecknoglobals
var generators = map[string]DistributionGenerator{
uniformGenerator: bender.UniformIntervalGenerator,
exponentialGenerator: bender.ExponentialIntervalGenerator,
}
// Distribution represents an interval generator flag value. *Distribution
// provides the String, Set and Type methods and is used as a pflag.Value.
type Distribution struct {
// Name is the full name of the selected generator (a key of generators).
Name string
// generator is the function registered under Name; it is nil for a
// zero-value Distribution on which Set has never succeeded.
generator DistributionGenerator
}
// NewDefaultDistribution returns a new distribution flag value preselected to
// the uniform generator.
func NewDefaultDistribution() *Distribution {
	d := new(Distribution)
	d.Name = uniformGenerator
	d.generator = generators[uniformGenerator]

	return d
}
// ErrInvalidGenerator is returned (wrapped) by Distribution.Set when the given
// value matches no known generator or is an ambiguous prefix of several.
var ErrInvalidGenerator = errors.New("invalid generator")
// DistributionChoices returns the names of all available generators in
// ascending lexical order. A new slice is built on every call.
func DistributionChoices() []string {
	names := make([]string, 0, len(generators))
	for name := range generators {
		names = append(names, name)
	}

	// Map iteration order is unspecified, so sort for a stable result.
	sort.Strings(names)

	return names
}
// String returns Name, the full name of the currently selected generator.
func (d *Distribution) String() string {
return d.Name
}
// Set validates value and selects the generator it identifies.
//
// value may be the full name of a generator or a prefix of one. An exact name
// always wins; otherwise the prefix must match exactly one generator name.
// On success both Name (set to the full generator name) and the generator
// function are updated. On failure d is left unchanged and the returned error
// wraps ErrInvalidGenerator: either no generator matched, or the prefix was
// ambiguous (this includes the empty string whenever more than one generator
// is registered).
func (d *Distribution) Set(value string) error {
	// An exact name takes priority over prefix matching so that a generator
	// whose name is itself a prefix of another generator's name can still be
	// selected instead of being reported as ambiguous.
	if generator, ok := generators[value]; ok {
		d.Name = value
		d.generator = generator

		return nil
	}

	var matches []string

	for key := range generators {
		if strings.HasPrefix(key, value) {
			matches = append(matches, key)
		}
	}

	switch len(matches) {
	case 0:
		choices := ChoicesString(DistributionChoices())

		return fmt.Errorf("%w, want: %s, got: %q", ErrInvalidGenerator, choices, value)
	case 1:
		d.Name = matches[0]
		d.generator = generators[matches[0]]

		return nil
	default:
		// Map iteration order is unspecified; sort so the message is stable.
		sort.Strings(matches)

		return fmt.Errorf("%w, ambiguous prefix %q matches: %s", ErrInvalidGenerator, value, ChoicesString(matches))
	}
}
// Type returns the type name of this flag value, which is always
// "distribution".
func (d *Distribution) Type() string {
return "distribution"
}
// Get returns the currently selected distribution generator. The result is
// nil for a zero-value Distribution on which Set has never succeeded.
func (d *Distribution) Get() DistributionGenerator {
return d.generator
}
// GetDistribution looks up the flag called name in f and returns the
// distribution generator stored in its value. It returns an error wrapping
// ErrUndefined when f has no flag with that name; otherwise the result
// (including any type-mismatch error) comes from GetDistributionValue.
func GetDistribution(f *pflag.FlagSet, name string) (DistributionGenerator, error) {
	if entry := f.Lookup(name); entry != nil {
		return GetDistributionValue(entry.Value)
	}

	return nil, fmt.Errorf("%w: %q", ErrUndefined, name)
}
// GetDistributionValue extracts the distribution generator from a pflag value.
// It returns an error wrapping ErrInvalidType, naming the value's reported
// type, when v is not a *Distribution.
func GetDistributionValue(v pflag.Value) (DistributionGenerator, error) {
	distribution, isDistribution := v.(*Distribution)
	if !isDistribution {
		return nil, fmt.Errorf("%w, want: distribution, got: %s", ErrInvalidType, v.Type())
	}

	return distribution.Get(), nil
}
// Bash completion function constants. Both are passed to utils.BashCompletion
// by BashCompletionDistribution.
const (
// fnameDistribution is the name of the bash completion function.
fnameDistribution = "__fbender_handle_distribution_flag"
// fbodyDistribution is the body of the bash completion function.
// NOTE: the word list is hard-coded and duplicates the keys of generators;
// keep the two in sync when adding or removing a generator.
fbodyDistribution = `COMPREPLY=($(compgen -W "uniform exponential" -- "${cur}"))`
)
// BashCompletionDistribution registers bash completion for the distribution
// flag called name in f on cmd. It returns an error wrapping ErrUndefined when
// f has no such flag and one wrapping ErrInvalidType when the flag's value is
// not a *Distribution; otherwise it returns the result of
// utils.BashCompletion.
func BashCompletionDistribution(cmd *cobra.Command, f *pflag.FlagSet, name string) error {
	target := f.Lookup(name)
	if target == nil {
		return fmt.Errorf("%w: %q", ErrUndefined, name)
	}

	switch target.Value.(type) {
	case *Distribution:
		return utils.BashCompletion(cmd, f, name, fnameDistribution, fbodyDistribution)
	default:
		return fmt.Errorf("%w, want: distribution, got: %s", ErrInvalidType, target.Value.Type())
	}
}