flags/constraint.go (69 lines of code) (raw):

/* Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. This source code is licensed under the BSD-style license found in the LICENSE file in the root directory of this source tree. */ package flags import ( "encoding/csv" "fmt" "strings" "github.com/facebookincubator/fbender/tester" "github.com/spf13/pflag" ) // ConstraintSliceValue is a pflag value storing constraints. type ConstraintSliceValue struct { Parsers []tester.MetricParser value *[]*tester.Constraint changed bool } // NewConstraintSliceValue creates a new constraint slice value for pflag. func NewConstraintSliceValue(parsers ...tester.MetricParser) *ConstraintSliceValue { v := []*tester.Constraint{} return &ConstraintSliceValue{ Parsers: parsers, value: &v, changed: false, } } func readAsCSV(val string) ([]string, error) { if val == "" { return []string{}, nil } stringReader := strings.NewReader(val) csvReader := csv.NewReader(stringReader) return csvReader.Read() } // Set validates given string given constraints and parses them to constraint // structures using metric parsers. func (c *ConstraintSliceValue) Set(value string) error { values, err := readAsCSV(value) if err != nil { return err } constraints := []*tester.Constraint{} for _, v := range values { constraint, err := tester.ParseConstraint(v, c.Parsers...) if err != nil { return fmt.Errorf("error parsing constraint %q: %w", v, err) } constraints = append(constraints, constraint) } if !c.changed { *c.value = constraints } else { *c.value = append(*c.value, constraints...) } c.changed = true return nil } // Type returns the ConstraintSliceValue Type. func (c *ConstraintSliceValue) Type() string { return "constraints" } func (c *ConstraintSliceValue) String() string { return fmt.Sprintf("%+v", *c.value) } // GetConstraints returns a constraints from a pflag set. func GetConstraints(f *pflag.FlagSet, name string) ([]*tester.Constraint, error) { flag := f.Lookup(name) if flag == nil { return nil, fmt.Errorf("%w: %q", ErrUndefined, name) } return GetConstraintsValue(flag.Value) } // GetConstraintsValue returns a constraints from a pflag value. func GetConstraintsValue(v pflag.Value) ([]*tester.Constraint, error) { if constraints, ok := v.(*ConstraintSliceValue); ok { return *constraints.value, nil } return nil, fmt.Errorf("%w, want: constraints, got: %s", ErrInvalidType, v.Type()) }