lib/limit.go (359 lines of code) (raw):

// Licensed to Elasticsearch B.V. under one or more contributor // license agreements. See the NOTICE file distributed with // this work for additional information regarding copyright // ownership. Elasticsearch B.V. licenses this file to you under // the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package lib import ( "fmt" "net/http" "net/url" "strconv" "strings" "time" "github.com/google/cel-go/cel" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" "golang.org/x/time/rate" ) // Limit returns a cel.EnvOption to configure extended functions for interpreting // request rate limit policies. // // It takes a mapping of policy names to policy interpreters to allow implementing // specific rate limit policies. The map returned by the policy functions should // have "rate" and "next" fields with type rate.Limit or string with the value "inf", // a "burst" field with type int and a "reset" field with type time.Time in the UTC // location. The semantics of "rate" and "burst" are described in the documentation // for the golang.org/x/time/rate package. // // The map may have other fields that can be logged. If a field named "error" // exists it should be a string with an error message indicating the result can // not be used. // // # Rate Limit // // rate_limit returns <map<string,dyn>> interpreted through the registered rate // limit policy or with a generalised policy constructor: // // rate_limit(<map<string,dyn>>, <string>, <duration>) -> <map<string,dyn>> // rate_limit(<map<string,dyn>>, <string>, <bool>, <bool>, <duration>, <int>) -> <map<string,dyn>> // // In the first form the string is the policy name and the duration is the default // quota window to use in the absence of window information from headers. // // In the second form the parameters are the header, the prefix for the rate limit // header keys, whether the keys are canonically formatted MIME header keys, // whether the reset header is a delta as opposed to a timestamp, the duration // of the quota window, and the burst rate. rate_limit in the second form will // never set a burst rate to zero. // // In all cases if any of the three rate limit headers is missing the rate_limit // call returns a map with only the headers written. This should be considered an // error condition. // // Examples: // // rate_limit(h, 'okta', duration('1m')) // rate_limit(h, 'draft', duration('1m')) // // // Similar semantics to the okta policy. // rate_limit(h, 'X-Rate-Limit', true, false, duration('1s'), 1) // // // Similar semantics to the draft policy in the simplest case. // rate_limit(h, 'Rate-Limit', true, true, duration('1s'), 1) // // // Non-canonical keys. // rate_limit(h, 'X-RateLimit', false, false, duration('1s'), 1) func Limit(policy map[string]LimitPolicy) cel.EnvOption { return cel.Lib(limitLib{policies: policy}) } type LimitPolicy func(header http.Header, window time.Duration) map[string]interface{} type limitLib struct { policies map[string]LimitPolicy } func (l limitLib) CompileOptions() []cel.EnvOption { return []cel.EnvOption{ cel.Function("rate_limit", cel.Overload( "map_dyn_rate_limit_string_duration", []*cel.Type{mapStringDyn, cel.StringType, cel.DurationType}, mapStringDyn, cel.FunctionBinding(catch(l.translatePolicy)), ), cel.Overload( "map_dyn_rate_limit_string_bool_bool_duration_int", []*cel.Type{mapStringDyn, cel.StringType, cel.BoolType, cel.BoolType, cel.DurationType, cel.IntType}, mapStringDyn, cel.FunctionBinding(catch(translatePolicy)), ), ), } } func (limitLib) ProgramOptions() []cel.ProgramOption { return nil } func (l limitLib) translatePolicy(args ...ref.Val) ref.Val { if len(args) != 3 { return types.NewErr("no such overload") } headers, ok := args[0].(traits.Mapper) if !ok { return types.ValOrErr(headers, "no such overload for headers: %s", args[0].Type()) } policy, ok := args[1].(types.String) if !ok { return types.ValOrErr(policy, "no such overload for policy: %s", args[1].Type()) } window, ok := args[2].(types.Duration) if !ok { return types.ValOrErr(window, "no such overload for window: %s", args[2].Type()) } translate, ok := l.policies[string(policy)] if !ok { return types.NewErr("unknown policy: %q", policy) } if translate == nil { return types.NewErr("policy is nil: %q", policy) } h, err := mapStrings(headers) if err != nil { return types.NewErr("%s", err) } return types.DefaultTypeAdapter.NativeToValue(translate(h, window.Duration)) } func mapStrings(val ref.Val) (map[string][]string, error) { iface := val.Value() switch iface := iface.(type) { case http.Header: return iface, nil case url.Values: return iface, nil case map[string][]string: return iface, nil case map[ref.Val]ref.Val: val := types.DefaultTypeAdapter.NativeToValue(iface) v, err := val.ConvertToNative(reflectMapStringStringSliceType) if err != nil { return nil, err } return v.(map[string][]string), nil case ref.Val: v, err := iface.ConvertToNative(reflectMapStringStringSliceType) if err != nil { return nil, err } return v.(map[string][]string), nil default: return nil, fmt.Errorf("invalid type: %T", iface) } } // OktaRateLimit implements the Okta rate limit policy translation. // It should be handed to the Limit lib with // // Limit(map[string]lib.LimitPolicy{ // "okta": lib.OktaRateLimit, // }) // // It will then be able to be used in a limit call with the window duration // given by the Okta documentation. // // Example: // // rate_limit(h, 'okta', duration('1m')) // // might return: // // { // "burst": 1, // "headers": "X-Rate-Limit-Limit=\"600\" X-Rate-Limit-Remaining=\"598\" X-Rate-Limit-Reset=\"1650094960\"", // "next": 10, // "rate": 0.9975873271836141, // "reset": "2022-04-16T07:48:40Z" // }, // // See https://developer.okta.com/docs/reference/rl-best-practices/ func OktaRateLimit(h http.Header, window time.Duration) map[string]interface{} { limit := h.Get("X-Rate-Limit-Limit") remaining := h.Get("X-Rate-Limit-Remaining") reset := h.Get("X-Rate-Limit-Reset") if limit == "" || remaining == "" || reset == "" { return map[string]interface{}{ "headers": fmt.Sprintf("X-Rate-Limit-Limit=%q X-Rate-Limit-Remaining=%q X-Rate-Limit-Reset=%q", limit, remaining, reset), } } lim, err := strconv.ParseFloat(limit, 64) if err != nil { return map[string]interface{}{ "headers": fmt.Sprintf("X-Rate-Limit-Limit=%q X-Rate-Limit-Remaining=%q X-Rate-Limit-Reset=%q", limit, remaining, reset), "error": err.Error(), } } rem, err := strconv.ParseFloat(remaining, 64) if err != nil { return map[string]interface{}{ "headers": fmt.Sprintf("X-Rate-Limit-Limit=%q X-Rate-Limit-Remaining=%q X-Rate-Limit-Reset=%q", limit, remaining, reset), "error": err.Error(), } } rst, err := strconv.ParseInt(reset, 10, 64) if err != nil { return map[string]interface{}{ "headers": fmt.Sprintf("X-Rate-Limit-Limit=%q X-Rate-Limit-Remaining=%q X-Rate-Limit-Reset=%q", limit, remaining, reset), "error": err.Error(), } } resetTime := time.Unix(rst, 0) per := time.Until(resetTime).Seconds() return map[string]interface{}{ "headers": fmt.Sprintf("X-Rate-Limit-Limit=%q X-Rate-Limit-Remaining=%q X-Rate-Limit-Reset=%q", limit, remaining, reset), "rate": rate.Limit(rem / per), "next": rate.Limit(lim / window.Seconds()), "burst": 1, // Be conservative here; the docs don't exactly specify burst rates. "reset": resetTime.UTC(), } } // DraftRateLimit implements the draft rate limit policy translation. // It should be handed to the Limit lib with // // Limit(map[string]lib.LimitPolicy{ // "draft": lib.DraftRateLimit, // }) // // It will then be able to be used in a limit call where the duration is // the default quota window. // // Example: // // rate_limit(h, 'draft', duration('60s')) // // might return something like: // // { // "burst": 1, // "headers": "Rate-Limit-Limit=\"5000\" Rate-Limit-Remaining=\"100\" Rate-Limit-Reset=\"Sat, 16 Apr 2022 07:48:40 GMT\"", // "next": 83.33333333333333, // "rate": 0.16689431007474315, // "reset": "2022-04-16T07:48:40Z" // } // // or // // { // "burst": 1000, // "headers": "Rate-Limit-Limit=\"12, 12;window=1; burst=1000;policy=\\\"leaky bucket\\\"\" Rate-Limit-Remaining=\"100\" Rate-Limit-Reset=\"Sat, 16 Apr 2022 07:48:40 GMT\"", // "next": 12, // "rate": 100, // "reset": "2022-04-16T07:48:40Z" // } // // See https://datatracker.ietf.org/doc/html/draft-polli-ratelimit-headers-00 func DraftRateLimit(h http.Header, window time.Duration) map[string]interface{} { limit := h.Get("Rate-Limit-Limit") remaining := h.Get("Rate-Limit-Remaining") reset := h.Get("Rate-Limit-Reset") if limit == "" || remaining == "" || reset == "" { return map[string]interface{}{ "headers": fmt.Sprintf("Rate-Limit-Limit=%q Rate-Limit-Remaining=%q Rate-Limit-Reset=%q", limit, remaining, reset), } } rem, err := strconv.ParseFloat(remaining, 64) if err != nil { return map[string]interface{}{ "headers": fmt.Sprintf("Rate-Limit-Limit=%q Rate-Limit-Remaining=%q Rate-Limit-Reset=%q", limit, remaining, reset), "error": err.Error(), } } var ( per float64 resetTime time.Time ) if d, err := strconv.ParseFloat(reset, 64); err == nil { per = d resetTime = time.Now().Add(time.Duration(d) * time.Second) } else if t, err := time.Parse(http.TimeFormat, reset); err == nil { per = time.Until(t).Seconds() resetTime = t } else if t, err := time.Parse(time.RFC1123, reset); err == nil { per = time.Until(t).Seconds() resetTime = t } else { return map[string]interface{}{ "headers": fmt.Sprintf("Rate-Limit-Limit=%q Rate-Limit-Remaining=%q Rate-Limit-Reset=%q", limit, remaining, reset), "error": fmt.Sprintf("could not parse %q as number or timestamp", reset), } } burst := 1 // Examine quota policies. limFields := strings.Split(limit, ",") quota, err := strconv.Atoi(limFields[0]) if err != nil { return map[string]interface{}{ "headers": fmt.Sprintf("Rate-Limit-Limit=%q Rate-Limit-Remaining=%q Rate-Limit-Reset=%q", limit, remaining, reset), "error": err.Error(), } } win := window.Seconds() for _, f := range limFields[1:] { p := policy(strings.TrimSpace(f)) q, err := p.quota() if err != nil { return map[string]interface{}{ "headers": fmt.Sprintf("Rate-Limit-Limit=%q Rate-Limit-Remaining=%q Rate-Limit-Reset=%q", limit, remaining, reset), "error": err.Error(), } } if q > quota { break } w, b, err := p.details(q) if err != nil { return map[string]interface{}{ "headers": fmt.Sprintf("Rate-Limit-Limit=%q Rate-Limit-Remaining=%q Rate-Limit-Reset=%q", limit, remaining, reset), "error": err.Error(), } } if w >= 0 { win = float64(w) } if b > 0 { burst = b } } return map[string]interface{}{ "headers": fmt.Sprintf("Rate-Limit-Limit=%q Rate-Limit-Remaining=%q Rate-Limit-Reset=%q", limit, remaining, reset), "rate": rate.Limit(rem / per), "next": rate.Limit(float64(quota) / win), "burst": burst, "reset": resetTime.UTC(), } } type policy string func (p policy) quota() (int, error) { idx := strings.Index(string(p), ";") if idx < 0 { return 0, fmt.Errorf("invalid policy: %q", p) } return strconv.Atoi(string(p[:idx])) } func (p policy) details(q int) (window, burst int, err error) { window = -1 burst = -1 for _, f := range strings.Split(string(p), ";") { f := strings.TrimSpace(f) switch { case strings.HasPrefix(f, "window="): window, err = strconv.Atoi(strings.TrimPrefix(f, "window=")) if err != nil { return window, burst, err } case strings.HasPrefix(f, "burst="): burst, err = strconv.Atoi(strings.TrimPrefix(f, "burst=")) if err != nil { return window, burst, err } } } return window, burst, nil } func translatePolicy(args ...ref.Val) ref.Val { if len(args) != 6 { return types.NewErr("no such overload") } headers, ok := args[0].(traits.Mapper) if !ok { return types.ValOrErr(headers, "no such overload for headers: %s", args[0].Type()) } h, err := mapStrings(headers) if err != nil { return types.NewErr("%s", err) } prefix, ok := args[1].(types.String) if !ok { return types.ValOrErr(prefix, "no such overload for prefix: %s", args[1].Type()) } canonical, ok := args[2].(types.Bool) if !ok { return types.ValOrErr(canonical, "no such overload for canonical: %s", args[1].Type()) } delta, ok := args[3].(types.Bool) if !ok { return types.ValOrErr(delta, "no such overload for delta: %s", args[2].Type()) } window, ok := args[4].(types.Duration) if !ok { return types.ValOrErr(window, "no such overload for window: %s", args[3].Type()) } burst, ok := args[5].(types.Int) if !ok { return types.ValOrErr(burst, "no such overload for burst: %s", args[4].Type()) } p := limitPolicy(h, string(prefix), bool(canonical), bool(delta), window.Duration, int(burst)) return types.DefaultTypeAdapter.NativeToValue(p) } func limitPolicy(h http.Header, prefix string, canonical, delta bool, window time.Duration, burst int) map[string]interface{} { get := getNonCanonical if canonical { get = http.Header.Get } limitKey := prefix + "-Limit" limit := get(h, limitKey) remainingKey := prefix + "-Remaining" remaining := get(h, remainingKey) resetKey := prefix + "-Reset" reset := get(h, resetKey) m := map[string]interface{}{ "headers": fmt.Sprintf("%s=%q %s=%q %s=%q", limitKey, limit, remainingKey, remaining, resetKey, reset), } if limit == "" || remaining == "" || reset == "" { return m } lim, err := strconv.ParseFloat(limit, 64) if err != nil { m["error"] = err.Error() return m } rem, err := strconv.ParseFloat(remaining, 64) if err != nil { m["error"] = err.Error() return m } var ( per float64 resetTime time.Time ) if d, err := strconv.ParseInt(reset, 10, 64); err == nil { if delta { per = float64(d) resetTime = time.Now().Add(time.Duration(d) * time.Second) } else { resetTime = time.Unix(d, 0) per = time.Until(resetTime).Seconds() } } else if t, err := time.Parse(http.TimeFormat, reset); err == nil { per = time.Until(t).Seconds() resetTime = t } else if t, err := time.Parse(time.RFC1123, reset); err == nil { per = time.Until(t).Seconds() resetTime = t } else { m["error"] = fmt.Sprintf("could not parse %q as number or timestamp", reset) return m } per *= window.Seconds() m["next"] = rate.Limit(lim / window.Seconds()) m["rate"] = rate.Limit(rem / per) if burst < 1 { burst = 1 } m["burst"] = burst m["reset"] = resetTime.UTC() return m } func getNonCanonical(h http.Header, k string) string { if h == nil { return "" } v := h[k] if len(v) == 0 { return "" } return v[0] }