module/apmsql/internal/pgutil/parser.go (156 lines of code) (raw):

// Licensed to Elasticsearch B.V. under one or more contributor // license agreements. See the NOTICE file distributed with // this work for additional information regarding copyright // ownership. Elasticsearch B.V. licenses this file to you under // the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package pgutil import ( "fmt" "net" nurl "net/url" "os" "sort" "strconv" "strings" "unicode" "go.elastic.co/apm/module/apmsql/v2" ) const ( defaultPostgresPort = 5432 ) // ParseDSN parses the given lib/pq or pgx/v4 datasource name, which may // be either a URL or connection string. func ParseDSN(name string) apmsql.DSNInfo { if connStr, err := parseURL(name); err == nil { name = connStr } opts := make(values) opts["host"] = os.Getenv("PGHOST") opts["port"] = os.Getenv("PGPORT") if err := parseOpts(name, opts); err != nil { // pq.Open will fail with the same error, // so just return a zero value. return apmsql.DSNInfo{} } addr, port := getAddr(opts) info := apmsql.DSNInfo{ Address: addr, Port: port, Database: opts["dbname"], User: opts["user"], } if info.Database == "" { info.Database = os.Getenv("PGDATABASE") } if info.User == "" { info.User = os.Getenv("PGUSER") } return info } func getAddr(opts values) (string, int) { hostOpt := opts["host"] if hostOpt == "" { hostOpt = "localhost" } else if strings.HasPrefix(hostOpt, "/") { // We don't report Unix addresses. return "", 0 } else if n := len(hostOpt); n > 1 && hostOpt[0] == '[' && hostOpt[n-1] == ']' { hostOpt = hostOpt[1 : n-1] } port := defaultPostgresPort if portOpt := opts["port"]; portOpt != "" { if v, err := strconv.Atoi(portOpt); err == nil { port = v } } return hostOpt, port } // Code below is copied from github.com/lib/pq (see NOTICE). // parseURL no longer needs to be used by clients of this library since supplying a URL as a // connection string to sql.Open() is now supported: // // sql.Open("postgres", "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full") // // It remains exported here for backwards-compatibility. // // ParseURL converts a url to a connection string for driver.Open. // Example: // // "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full" // // converts to: // // "user=bob password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-full" // // A minimal example: // // "postgres://" // // This will be blank, causing driver.Open to use all of the defaults func parseURL(url string) (string, error) { u, err := nurl.Parse(url) if err != nil { return "", err } if u.Scheme != "postgres" && u.Scheme != "postgresql" { return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) } var kvs []string escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`) accrue := func(k, v string) { if v != "" { kvs = append(kvs, k+"="+escaper.Replace(v)) } } if u.User != nil { v := u.User.Username() accrue("user", v) v, _ = u.User.Password() accrue("password", v) } if host, port, err := net.SplitHostPort(u.Host); err != nil { accrue("host", u.Host) } else { accrue("host", host) accrue("port", port) } if u.Path != "" { accrue("dbname", u.Path[1:]) } q := u.Query() for k := range q { accrue(k, q.Get(k)) } sort.Strings(kvs) // Makes testing easier (not a performance concern) return strings.Join(kvs, " "), nil } type values map[string]string // scanner implements a tokenizer for libpq-style option strings. type scanner struct { s []rune i int } // newScanner returns a new scanner initialized with the option string s. func newScanner(s string) *scanner { return &scanner{[]rune(s), 0} } // Next returns the next rune. // It returns 0, false if the end of the text has been reached. func (s *scanner) Next() (rune, bool) { if s.i >= len(s.s) { return 0, false } r := s.s[s.i] s.i++ return r, true } // SkipSpaces returns the next non-whitespace rune. // It returns 0, false if the end of the text has been reached. func (s *scanner) SkipSpaces() (rune, bool) { r, ok := s.Next() for unicode.IsSpace(r) && ok { r, ok = s.Next() } return r, ok } // parseOpts parses the options from name and adds them to the values. // // The parsing code is based on conninfo_parse from libpq's fe-connect.c func parseOpts(name string, o values) error { s := newScanner(name) for { var ( keyRunes, valRunes []rune r rune ok bool ) if r, ok = s.SkipSpaces(); !ok { break } // Scan the key for !unicode.IsSpace(r) && r != '=' { keyRunes = append(keyRunes, r) if r, ok = s.Next(); !ok { break } } // Skip any whitespace if we're not at the = yet if r != '=' { r, ok = s.SkipSpaces() } // The current character should be = if r != '=' || !ok { return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes)) } // Skip any whitespace after the = if r, ok = s.SkipSpaces(); !ok { // If we reach the end here, the last value is just an empty string as per libpq. o[string(keyRunes)] = "" break } if r != '\'' { for !unicode.IsSpace(r) { if r == '\\' { if r, ok = s.Next(); !ok { return fmt.Errorf(`missing character after backslash`) } } valRunes = append(valRunes, r) if r, ok = s.Next(); !ok { break } } } else { quote: for { if r, ok = s.Next(); !ok { return fmt.Errorf(`unterminated quoted string literal in connection string`) } switch r { case '\'': break quote case '\\': r, _ = s.Next() fallthrough default: valRunes = append(valRunes, r) } } } o[string(keyRunes)] = string(valRunes) } return nil }