runtime/router/trie.go (227 lines of code) (raw):
// Copyright (c) 2023 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package router
import (
"errors"
"fmt"
"net/http"
"strings"
)
// Sentinel errors produced by Trie operations.
var (
	// errNotFound is returned by lookups when no registered path matches.
	errNotFound = errors.New("not found")
	// errExist is returned when a value is already registered for a conflicting path.
	errExist = errors.New("path value already set")
	// errPath is returned when a path is empty or contains "//".
	errPath = errors.New("bad path")
)
// paramMismatch is returned by tnode.set when a path being registered has a colon
// param segment that does not agree with the param segment already stored at the
// same position in the trie.
type paramMismatch struct {
	expected     string // existing param segment, including the leading ':'
	actual       string // param segment of the path being registered
	existingPath string // concatenated keys from the root down to the conflicting node
}

// Error implements the error interface.
func (e *paramMismatch) Error() string {
	const format = "param key mismatch: expected is %s but got %s"
	return fmt.Sprintf(format, e.expected, e.actual)
}
// Param is a single url parameter captured while matching a url against a
// registered path pattern. Key is the pattern segment name without its leading ':',
// and Value is the corresponding segment of the matched url.
// For example, url "/foo/123" matched against "/foo/:id" yields Key "id" and Value "123".
type Param struct {
	Key   string
	Value string
}
// Trie is a radix trie that stores an http.Handler at a given url path. A trie node
// corresponds to an arbitrary substring of a path; concatenating the keys on the way
// from the root to a node yields the path prefix that node represents.
type Trie struct {
	// root has an empty key when the trie is created by NewTrie.
	root *tnode
}

// tnode is a single node of a Trie.
type tnode struct {
	// key is the path substring held by this node.
	key string
	// value is the handler registered for the path ending at this node, or nil.
	value http.Handler
	// children are tried in slice order; addChildren decides where a new child goes.
	children []*tnode
}
// NewTrie creates an empty trie whose root node has an empty key, no handler and
// no children.
func NewTrie() *Trie {
	t := new(Trie)
	t.root = new(tnode)
	return t
}
// Set registers value as the handler for path. It returns an error if the path is
// malformed or if a handler is already registered for a path that matches it.
//
// path is normalized first: a leading "/" is added when missing and one trailing "/"
// is dropped. Empty paths and paths containing "//" are rejected with errPath.
//
// Once set, a subsequent Get returns value for any url that matches path. A match is
// either equality (e.g. url "/foo" and path "/foo") or a match against a path
// pattern, which has two forms:
//   - path ends with "/*", e.g. urls "/foo" and "/foo/bar" both match path "/*";
//     "*" may appear at most once and only as the last path segment
//   - path contains a colon wildcard ("/:"), e.g. urls "/a/b" and "/a/c" both match
//     path "/a/:var"
//
// isWhitelisted - enables special behavior under which different handlers can be
// configured for paths such as /a and /:b in the router.
func (t *Trie) Set(path string, value http.Handler, isWhitelisted bool) error {
	if len(path) == 0 || strings.Contains(path, "//") {
		return errPath
	}
	if path[0] != '/' {
		path = "/" + path
	}
	path = strings.TrimSuffix(path, "/") // a trailing slash is not significant

	// "*" is only allowed once, as the final "/*" segment.
	if strings.Contains(path, "*") {
		if !strings.HasSuffix(path, "/*") {
			return errors.New("/* must be the last path segment")
		}
		if strings.Count(path, "*") > 1 {
			return errors.New("path can not contain more than one *")
		}
	}

	// colonAsPattern is forwarded to the conflict lookups done while inserting.
	colonAsPattern := !isWhitelisted
	err := t.root.set(path, value, false, false, colonAsPattern, isWhitelisted)

	var mismatch *paramMismatch
	if !errors.As(err, &mismatch) {
		return err
	}
	return fmt.Errorf("path %q has a different param key %q, it should be the same key %q as in existing path %q",
		path, mismatch.actual, mismatch.expected, mismatch.existingPath)
}
// Get returns the http.Handler registered for a path matching the given url path,
// together with the url params captured by colon wildcard segments. For example, if
// a handler is registered for "/:foo/bar", calling Get with "/xyz/bar" returns a
// single param whose key is "foo" and value is "xyz".
//
// path is normalized like in Set. errPath is returned for an empty path or one that
// contains "//"; errNotFound is returned when no registered path matches.
//
// isWhitelisted - enables special behavior under which different handlers can be
// configured for paths such as /a and /:b in the router.
func (t *Trie) Get(path string, isWhitelisted bool) (http.Handler, []Param, error) {
	switch {
	case path == "":
		return nil, nil, errPath
	case strings.Contains(path, "//"):
		return nil, nil, errPath
	case !strings.HasPrefix(path, "/"):
		path = "/" + path
	}
	path = strings.TrimSuffix(path, "/") // a trailing slash is not significant

	// isWhitelisted doubles as colonAsPattern: for whitelisted lookups a ':' key
	// segment then matches any url segment (see isGetWildCardPattern).
	return t.root.get(path, false, false, isWhitelisted, isWhitelisted)
}
// set sets the handler for given path, creates new child node if necessary
// lastKeyCharSlash tracks whether the previous key char is a '/', used to decide it is a pattern or not
// when the current key char is ':'. lastPathCharSlash tracks whether the previous path char is a '/',
// used to decide it is a pattern or not when the current path char is ':'.
func (t *tnode) set(path string, value http.Handler, lastKeyCharSlash, lastPathCharSlash, colonAsPattern, isWhitelisted bool) error {
// find the longest common prefix
var shorterLength, i int
keyLength, pathLength := len(t.key), len(path)
if keyLength > pathLength {
shorterLength = pathLength
} else {
shorterLength = keyLength
}
for i < shorterLength && t.key[i] == path[i] {
i++
}
// Find the first character that differs between "path" and this node's key, if it exists.
// If we encounter a colon wildcard, ensure that the wildcard in path matches the wildcard
// in this node's key for that segment. The segment is a colon wildcard only when the colon
// is immediately after slash, e.g. "/:foo", "/x/:y". "/a:b" is not a colon wildcard segment.
var keyMatchIdx, pathMatchIdx int
for keyMatchIdx < keyLength && pathMatchIdx < pathLength {
if t.isSetWildCardPattern(path, keyMatchIdx, pathMatchIdx, lastKeyCharSlash, lastPathCharSlash, isWhitelisted) {
keyStartIdx, pathStartIdx := keyMatchIdx, pathMatchIdx
same := t.key[keyMatchIdx] == path[pathMatchIdx]
for keyMatchIdx < keyLength && t.key[keyMatchIdx] != '/' {
keyMatchIdx++
}
for pathMatchIdx < pathLength && path[pathMatchIdx] != '/' {
pathMatchIdx++
}
if same && (keyMatchIdx-keyStartIdx) != (pathMatchIdx-pathStartIdx) {
return ¶mMismatch{
t.key[keyStartIdx:keyMatchIdx],
path[pathStartIdx:pathMatchIdx],
t.key,
}
}
} else if t.key[keyMatchIdx] == path[pathMatchIdx] {
keyMatchIdx++
pathMatchIdx++
} else {
break
}
lastKeyCharSlash = t.key[keyMatchIdx-1] == '/'
lastPathCharSlash = path[pathMatchIdx-1] == '/'
}
// If the node key is fully matched, we match the rest path with children nodes to see if a value
// already exists for the path.
if keyMatchIdx == keyLength {
for _, c := range t.children {
if _, _, err := c.get(path[pathMatchIdx:], lastKeyCharSlash, lastPathCharSlash, colonAsPattern, isWhitelisted); err == nil {
return errExist
}
}
}
// node key is longer than longest common prefix
if i < keyLength {
// key/path suffix being "*" means a conflict
if path[i:] == "*" || t.key[i:] == "*" {
return errExist
}
// split the node key, add new node with node key minus longest common prefix
split := &tnode{
key: t.key[i:],
value: t.value,
children: t.children,
}
t.key = t.key[:i]
t.value = nil
t.children = []*tnode{split}
// path is equal to longest common prefix
// set value on current node after split
if i == pathLength {
t.value = value
} else {
// path is longer than longest common prefix
// add new node with path minus longest common prefix
newNode := &tnode{
key: path[i:],
value: value,
}
t.addChildren(newNode, lastPathCharSlash)
}
}
// node key is equal to longest common prefix
if i == keyLength {
// path is equal to longest common prefix
if i == pathLength {
// node is guaranteed to have zero value,
// otherwise it would have caused errExist earlier
t.value = value
} else {
// path is longer than node key, try to recurse on node children
for _, c := range t.children {
if c.key[0] == path[i] {
lastKeyCharSlash = i > 0 && t.key[i-1] == '/'
lastPathCharSlash = i > 0 && path[i-1] == '/'
err := c.set(path[i:], value, lastKeyCharSlash, lastPathCharSlash, colonAsPattern, isWhitelisted)
if e, ok := err.(*paramMismatch); ok {
e.existingPath = t.key + e.existingPath
return e
}
return err
}
}
// no children to recurse, add node with path minus longest common path
newNode := &tnode{
key: path[i:],
value: value,
}
t.addChildren(newNode, lastPathCharSlash)
}
}
return nil
}
// get looks up path in the subtree rooted at t. It returns the matched handler and
// the url params captured along the way, or errNotFound when nothing matches.
//
// lastKeyCharSlash and lastPathCharSlash report whether the character preceding the
// current position of the key and of path is '/'; a ':' only starts a wildcard
// segment right after a '/'. colonAsPattern and isWhitelistedPath select the
// wildcard rules applied by isGetWildCardPattern.
func (t *tnode) get(path string, lastKeyCharSlash, lastPathCharSlash, colonAsPattern, isWhitelistedPath bool) (http.Handler, []Param, error) {
	key := t.key
	var params []Param
	k, p := 0, 0 // read positions in key and path

	// Walk key and path in lockstep, consuming a whole segment on each side at wildcards.
scan:
	for k < len(key) && p < len(path) {
		switch {
		case t.isGetWildCardPattern(path, k, p, lastKeyCharSlash, lastPathCharSlash, colonAsPattern, isWhitelistedPath):
			// Wildcard segment: consume both sides up to the next '/'.
			keySeg, pathSeg := k, p
			for k < len(key) && key[k] != '/' {
				k++
			}
			for p < len(path) && path[p] != '/' {
				p++
			}
			// Only a ':' segment in the key yields a param; its name excludes the ':'.
			if key[keySeg] == ':' {
				params = append(params, Param{key[keySeg+1 : k], path[pathSeg:p]})
			}
		case key[k] == path[p]:
			k++
			p++
		default:
			break scan
		}
		lastKeyCharSlash = key[k-1] == '/'
		lastPathCharSlash = path[p-1] == '/'
	}

	if k < len(key) {
		// The key was not fully consumed. It still matches when its unconsumed
		// remainder is exactly "*" and path has characters left for it to cover.
		if key[k:] == "*" && p < len(path) {
			return t.value, params, nil
		}
		return nil, nil, errNotFound
	}
	// The key was fully consumed and ends in '*' (for example matched by a literal
	// '*' in path, or swallowed by a ':' wildcard segment of path); any remainder of
	// path is covered by it.
	if k > 0 && key[k-1] == '*' {
		return t.value, params, nil
	}
	if p == len(path) {
		// Key and path are both fully consumed: a match only if a handler is stored here.
		if t.value == nil {
			return nil, nil, errNotFound
		}
		return t.value, params, nil
	}
	// The key is fully consumed but path is not: try the children in order on the rest.
	rest := path[p:]
	for _, child := range t.children {
		if v, ps, err := child.get(rest, lastKeyCharSlash, lastPathCharSlash, colonAsPattern, isWhitelistedPath); err == nil {
			return v, append(params, ps...), nil
		}
	}
	return nil, nil, errNotFound
}
// addChildren attaches child to t. A child that begins a new path segment
// (lastPathCharSlash is true) with a literal, i.e. its key does not start with ':',
// is placed first, so it is tried before its siblings when children are iterated in
// order. Every other child, including ':var' patterns, is placed last.
func (t *tnode) addChildren(child *tnode, lastPathCharSlash bool) {
	putFirst := lastPathCharSlash && child.key[0] != ':'
	if !putFirst {
		t.children = append(t.children, child)
		return
	}
	t.children = append([]*tnode{child}, t.children...)
}
// isSetWildCardPattern reports whether, during set, the key character at keyIdx and
// the path character at pathIdx begin a colon wildcard segment that is compared as a
// whole up to the next '/'. A ':' counts as a colon wildcard only directly after a
// '/', as tracked by lastKeyCharSlash and lastPathCharSlash.
//
// For whitelisted paths both the key and the path must hold a colon wildcard; for
// other paths a colon wildcard on either side is enough.
func (t *tnode) isSetWildCardPattern(path string, keyIdx, pathIdx int, lastKeyCharSlash, lastPathCharSlash, isWhitelistedPath bool) bool {
	keyIsParam := t.key[keyIdx] == ':' && lastKeyCharSlash
	if !isWhitelistedPath {
		return keyIsParam || (path[pathIdx] == ':' && lastPathCharSlash)
	}
	return keyIsParam && path[pathIdx] == ':' && lastPathCharSlash
}
// isGetWildCardPattern reports whether, during get, the key character at keyIdx and
// the path character at pathIdx begin a wildcard segment that is matched as a whole
// up to the next '/'. A ':' counts as a colon wildcard only directly after a '/', as
// tracked by lastKeyCharSlash and lastPathCharSlash.
//
// For whitelisted paths the key must hold a colon wildcard, and the path must hold
// one too unless colonAsPattern is true. For other paths either the key holds a
// colon wildcard, or the path holds one and colonAsPattern is true.
func (t *tnode) isGetWildCardPattern(path string, keyIdx, pathIdx int, lastKeyCharSlash, lastPathCharSlash, colonAsPattern, isWhitelistedPath bool) bool {
	keyIsParam := t.key[keyIdx] == ':' && lastKeyCharSlash
	if !isWhitelistedPath {
		return keyIsParam || (path[pathIdx] == ':' && lastPathCharSlash && colonAsPattern)
	}
	return keyIsParam && ((path[pathIdx] == ':' && lastPathCharSlash) || colonAsPattern)
}