lib/validator/policy_builder.go (265 lines of code) (raw):
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package validator
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/GoogleCloudPlatform/healthcare-federated-access-services/lib/httputils" /* copybara-comment: httputils */
"github.com/GoogleCloudPlatform/healthcare-federated-access-services/lib/strutil" /* copybara-comment: strutil */
cpb "github.com/GoogleCloudPlatform/healthcare-federated-access-services/proto/common/v1" /* copybara-comment: go_proto */
pb "github.com/GoogleCloudPlatform/healthcare-federated-access-services/proto/dam/v1" /* copybara-comment: go_proto */
)
// byValues is the set of entries accepted in a condition clause's "By" field.
// expandBy rejects any entry that is not a key of this map.
var byValues = map[string]bool{
	"dac":    true,
	"peer":   true,
	"self":   true,
	"so":     true,
	"system": true,
}
// BuildPolicyValidator creates a new policy validator from policy.AnyOf.
//
// defs holds the known visa types, sources the named trusted sources, and
// args the values substituted for variables in clause values. Any error from
// building the underlying validators is returned unchanged with a nil Policy.
func BuildPolicyValidator(ctx context.Context, policy *pb.Policy, defs map[string]*pb.VisaType, sources map[string]*pb.TrustedSource, args map[string]string) (*Policy, error) {
	conditions, buildErr := policyValidator(ctx, policy.AnyOf, defs, sources, args)
	if buildErr != nil {
		return nil, buildErr
	}

	built := NewPolicy(conditions, nil)
	return built, nil
}
// policyValidator turns a list of condition sets into one Validator.
//
// Every clause of a set's AllOf becomes a claim validator; the clauses of one
// set are combined with And and the sets themselves are combined with Or.
// It returns (nil, nil) when anyOf is empty, and stops at the first error
// reported while checking the visa type or expanding a clause's source,
// value or by field.
func policyValidator(ctx context.Context, anyOf []*cpb.ConditionSet, defs map[string]*pb.VisaType, sources map[string]*pb.TrustedSource, args map[string]string) (Validator, error) {
	if len(anyOf) == 0 {
		return nil, nil
	}

	var alternatives []Validator
	for _, set := range anyOf {
		var required []Validator
		for _, clause := range set.AllOf {
			// The visa type must be defined before anything else is expanded.
			if err := validateVisaType(clause.Type, defs); err != nil {
				return nil, err
			}

			trusted, err := expandSources(clause.Type, clause.Source, sources)
			if err != nil {
				return nil, err
			}

			values, err := expandValues(clause.Value, args)
			if err != nil {
				return nil, err
			}

			byEntries, err := expandBy(clause.By)
			if err != nil {
				return nil, err
			}

			claim, err := NewClaimValidator(clause.Type, values, "", trusted, byEntries)
			if err != nil {
				return nil, err
			}
			required = append(required, claim)
		}
		alternatives = append(alternatives, And(required))
	}

	return Or(alternatives), nil
}
// expandSources resolves a clause's source field into a set of source strings.
//
// src is expanded with expandField. Each resulting entry that strutil.IsURL
// accepts is added as is; any other entry must be a key of sources and
// contributes that trusted source's Sources list, otherwise an error is
// returned.
//
// When src expands to no entries, every trusted source whose VisaTypes list is
// empty, or contains an entry matching visaType, contributes its Sources. A
// VisaTypes entry longer than one character that starts with '^' is compiled
// as a regular expression; any other entry is compared for equality.
func expandSources(visaType string, src string, sources map[string]*pb.TrustedSource) (map[string]bool, error) {
	names, err := expandField(src)
	if err != nil {
		return nil, err
	}

	result := make(map[string]bool)

	// Explicit sources: URLs are taken verbatim, names are looked up.
	if len(names) > 0 {
		for _, name := range names {
			if strutil.IsURL(name) {
				result[name] = true
				continue
			}
			trusted, found := sources[name]
			if !found {
				return nil, fmt.Errorf("from %q name is not a valid source name", name)
			}
			for _, s := range trusted.Sources {
				result[s] = true
			}
		}
		return result, nil
	}

	// No explicit sources: fall back to every trusted source that applies to
	// this visa type.
	for sname, trusted := range sources {
		// A trusted source without a VisaTypes list applies to all visa types.
		matched := len(trusted.VisaTypes) == 0
		for vidx, candidate := range trusted.VisaTypes {
			if len(candidate) > 1 && candidate[0] == '^' {
				// Regexp
				re, err := regexp.Compile(candidate)
				if err != nil {
					return nil, fmt.Errorf("source %q visa %d invalid regular expression %q: %v", sname, vidx, candidate, err)
				}
				matched = re.MatchString(visaType)
			} else {
				matched = candidate == visaType
			}
			if matched {
				break
			}
		}
		if !matched {
			continue
		}
		for _, s := range trusted.Sources {
			result[s] = true
		}
	}
	return result, nil
}
// expandValues expands a clause's value field with expandField and, when args
// is non-nil, replaces variables in every resulting entry through
// strutil.ReplaceVariables. With nil args the expanded entries are returned
// untouched. The first expansion or substitution error is returned with a nil
// slice.
func expandValues(input string, args map[string]string) ([]string, error) {
	fields, err := expandField(input)
	switch {
	case err != nil:
		return nil, err
	case args == nil:
		// Nothing to substitute.
		return fields, nil
	}

	for idx := range fields {
		replaced, err := strutil.ReplaceVariables(fields[idx], args)
		if err != nil {
			return nil, err
		}
		fields[idx] = replaced
	}
	return fields, nil
}
// expandBy expands a clause's by field with expandField and returns the
// entries as a set. An entry that is not a key of byValues yields an error.
// An empty input produces an empty, non-nil set.
func expandBy(input string) (map[string]bool, error) {
	entries, err := expandField(input)
	if err != nil {
		return nil, err
	}

	accepted := make(map[string]bool, len(entries))
	for _, entry := range entries {
		_, known := byValues[entry]
		if !known {
			return nil, fmt.Errorf("by %q is not supported", entry)
		}
		accepted[entry] = true
	}
	return accepted, nil
}
// expandField parses a field of the form "<type>:<body>" into a list of
// strings.
//
// An empty input yields (nil, nil). The text before the first ':' selects how
// the body is interpreted:
//
//	const          the body is returned as the single entry, unchanged
//	pattern        the body is passed through toPattern
//	split_pattern  the body is split on ';' and each part passed through toPattern
//
// A missing ':', an empty body or any other type is an error.
func expandField(input string) ([]string, error) {
	if input == "" {
		return nil, nil
	}

	sep := strings.Index(input, ":")
	if sep < 0 {
		return nil, fmt.Errorf("missing pattern type")
	}

	kind, body := input[:sep], input[sep+1:]
	if body == "" {
		return nil, fmt.Errorf("empty suffix")
	}

	// TODO: change this when using the new policy engine
	switch kind {
	case "const":
		return []string{body}, nil
	case "pattern":
		return []string{toPattern(body)}, nil
	case "split_pattern":
		parts := strings.Split(body, ";")
		for idx := range parts {
			parts[idx] = toPattern(parts[idx])
		}
		return parts, nil
	default:
		return nil, fmt.Errorf("pattern type %q not supported", kind)
	}
}
// toPattern converts a wildcard expression into an anchored regular
// expression string. Input containing neither '*' nor '?' is returned
// unchanged. Otherwise the input is regexp-quoted, every '*' becomes ".*",
// every '?' becomes ".", and the result is wrapped in "^" and "$".
//
// TODO: remove this helper function
func toPattern(input string) string {
	if !strings.ContainsAny(input, "*?") {
		return input
	}

	// After quoting, each wildcard shows up in its escaped form, so those
	// escaped forms are what get swapped for their regexp equivalents.
	wildcards := strings.NewReplacer(
		regexp.QuoteMeta("*"), ".*",
		regexp.QuoteMeta("?"), ".",
	)
	return "^" + wildcards.Replace(regexp.QuoteMeta(input)) + "$"
}
// validateVisaType reports an error when typ is not a key of defs.
func validateVisaType(typ string, defs map[string]*pb.VisaType) error {
	_, defined := defs[typ]
	if defined {
		return nil
	}
	return fmt.Errorf("visa type %q is undefined", typ)
}
// ValidatePolicy does basic validation for a policy and (optionally) the
// variable "args" that a policy instantiation uses.
//
// It checks, in this order:
//  1. every clause of policy.AnyOf: visa type, source, value (including
//     variable substitution and extraction) and by field;
//  2. every variable definition: a non-empty, compilable regexp, a UI
//     description and, when args is non-nil, a provided value that matches
//     the regexp;
//  3. every variable referenced by a clause value is defined (and provided
//     when args is non-nil);
//  4. every entry of args is referenced by some clause value.
//
// On failure it returns the status path of the offending field together with
// the error; on success it returns ("", nil). Steps 2-4 iterate over maps, so
// when several entries are invalid the one reported may vary between calls.
func ValidatePolicy(policy *pb.Policy, defs map[string]*pb.VisaType, sources map[string]*pb.TrustedSource, args map[string]string) (string, error) {
	instantiated := args != nil

	substitutions := args
	if !instantiated {
		// To allow variable substitution to be attempted, set up variables to
		// substitute based on definitions (regex match not required).
		substitutions = make(map[string]string)
		for name := range policy.VariableDefinitions {
			substitutions[name] = "a"
		}
	}

	referenced := make(map[string]bool)
	for i, set := range policy.AnyOf {
		for j, clause := range set.AllOf {
			// at builds the status path of one field of the current clause.
			at := func(field string) string {
				return httputils.StatusPath("anyOf", strconv.Itoa(i), "allOf", strconv.Itoa(j), field)
			}

			if err := validateVisaType(clause.Type, defs); err != nil {
				return at("type"), err
			}
			if _, err := expandSources(clause.Type, clause.Source, sources); err != nil {
				return at("source"), err
			}
			if _, err := expandValues(clause.Value, substitutions); err != nil {
				return at("value"), err
			}

			// Record which variables this clause's value refers to.
			clauseVars, err := strutil.ExtractVariables(clause.Value)
			if err != nil {
				return at("value"), err
			}
			for name := range clauseVars {
				referenced[name] = true
			}

			if _, err := expandBy(clause.By); err != nil {
				return at("by"), err
			}
		}
	}

	for name, def := range policy.VariableDefinitions {
		if def.Regexp == "" {
			return httputils.StatusPath("variableDefinitions", name, "regexp"), fmt.Errorf("regular expression not specified")
		}
		re, err := regexp.Compile(def.Regexp)
		if err != nil {
			return httputils.StatusPath("variableDefinitions", name, "regexp"), fmt.Errorf("invalid regular expression: %v", err)
		}
		if instantiated {
			value, provided := args[name]
			if !provided {
				return httputils.StatusPath("variableDefinitions", name), fmt.Errorf("variable not provided")
			}
			if !re.MatchString(value) {
				return httputils.StatusPath("variableDefinitions", name), fmt.Errorf("variable value %q invalid format", value)
			}
		}
		// Indexing a nil map yields the zero value, so this also covers a
		// missing Ui map.
		if def.Ui["description"] == "" {
			return httputils.StatusPath("variableDefinitions", name, "ui", "description"), fmt.Errorf("description not provided")
		}
	}

	// Errors about variables are reported under "vars" when validating an
	// instantiation and under "variableDefinitions" otherwise.
	prefix := "variableDefinitions"
	if instantiated {
		prefix = "vars"
	}

	for name := range referenced {
		if len(policy.VariableDefinitions) == 0 {
			return httputils.StatusPath(prefix, name), fmt.Errorf("policy does not use variables")
		}
		_, defined := policy.VariableDefinitions[name]
		_, provided := args[name]
		if !defined || (instantiated && !provided) {
			return httputils.StatusPath(prefix, name), fmt.Errorf("undefined variable")
		}
	}

	for name := range args {
		if _, used := referenced[name]; !used {
			return httputils.StatusPath(prefix, name), fmt.Errorf("unused variable")
		}
	}
	return "", nil
}