sqlutil/scanner.go (272 lines of code) (raw):
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package sqlutil // import "go.elastic.co/apm/v2/sqlutil"
import (
"strings"
"unicode"
"unicode/utf8"
)
// Scanner is the struct used to generate SQL
// tokens for the parser.
type Scanner struct {
input string
start int // text start pos in bytes
end int // text end pos in bytes
pos int // read pos in bytes
tok Token
}
// NewScanner creates a new Scanner for sql.
func NewScanner(sql string) *Scanner {
return &Scanner{input: sql}
}
// Token returns the most recently scanned token.
func (s *Scanner) Token() Token {
return s.tok
}
// Text returns the portion of the string that relates to
// the most recently scanned token.
func (s *Scanner) Text() string {
return s.input[s.start:s.end]
}
// Scan scans for the next token and returns true if one was
// found, false if the end of the input stream was reached.
// When Scan returns true, the token type can be obtained by
// calling the Token() method, and the text can be obtained
// by calling the Text() method.
func (s *Scanner) Scan() bool {
s.tok = s.scan()
return s.tok != eof
}
func (s *Scanner) scan() Token {
r, ok := s.next()
if !ok {
return eof
}
for unicode.IsSpace(r) {
if r, ok = s.next(); !ok {
return eof
}
}
s.start = s.pos - utf8.RuneLen(r)
if r == '_' || unicode.IsLetter(r) {
return s.scanKeywordOrIdentifier(r != '_')
} else if unicode.IsDigit(r) {
return s.scanNumericLiteral()
}
switch r {
case '\'':
// Standard string literal.
return s.scanStringLiteral()
case '"':
// Standard double-quoted identifier.
//
// NOTE(axw) MySQL will treat " as a
// string literal delimiter by default,
// but we assume standard SQL and treat
// it as a identifier delimiter.
return s.scanQuotedIdentifier('"')
case '[':
// T-SQL bracket-quoted identifier.
return s.scanQuotedIdentifier(']')
case '`':
// MySQL-style backtick-quoted identifier.
return s.scanQuotedIdentifier('`')
case '(':
return LPAREN
case ')':
return RPAREN
case '-':
if next, ok := s.peek(); ok && next == '-' {
// -- comment
s.next()
return s.scanSimpleComment()
}
return OTHER
case '/':
if next, ok := s.peek(); ok {
switch next {
case '*':
// /* comment */
s.next()
return s.scanBracketedComment()
case '/':
// // comment
s.next()
return s.scanSimpleComment()
}
}
return OTHER
case '.':
return PERIOD
case '$':
next, ok := s.peek()
if !ok {
break
}
if unicode.IsDigit(next) {
// This is a variable like "$1".
for {
if next, ok := s.peek(); !ok || !unicode.IsDigit(next) {
break
}
s.next()
}
return OTHER
} else if next == '$' || next == '_' || unicode.IsLetter(next) {
// PostgreSQL supports dollar-quoted string literal syntax,
// like $foo$...$foo$. The tag (foo in this case) is optional,
// and if present follows identifier rules.
for {
r, ok := s.next()
if !ok {
// Unknown token starting with $ until
// EOF, just ignore it.
return OTHER
}
switch {
case r == '$':
// This marks the end of the initial $foo$.
tag := s.Text()
if i := strings.Index(s.input[s.pos:], tag); i >= 0 {
s.end += i + len(tag)
s.pos += i + len(tag)
return STRING
}
case unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_':
// Identifier rune, consume.
case unicode.IsSpace(r):
// Unknown token starting with $,
// consume runes until space.
s.end -= utf8.RuneLen(r)
return OTHER
}
}
}
return OTHER
}
return OTHER
}
func (s *Scanner) scanKeywordOrIdentifier(maybeKeyword bool) Token {
loop:
for {
r, ok := s.peek()
if !ok {
break loop
}
switch {
case unicode.IsLetter(r):
case unicode.IsDigit(r) || r == '_' || r == '$':
maybeKeyword = false
default:
break loop
}
s.next()
}
if !maybeKeyword {
return IDENT
}
text := s.Text()
if len(text) >= len(keywords) {
return IDENT
}
for _, token := range keywords[len(text)] {
if strings.EqualFold(text, token.String()) {
return token
}
}
return IDENT
}
func (s *Scanner) scanQuotedIdentifier(delim rune) Token {
loop:
for {
r, ok := s.next()
if !ok {
return eof
}
if r == delim {
if delim == '"' {
if r, ok := s.peek(); ok && r == delim {
// Skip escaped double quotes,
// e.g. "He said ""great""".
s.next()
continue loop
}
}
break
}
}
// Remove quotes from identifier.
s.start++
s.end--
return IDENT
}
func (s *Scanner) scanNumericLiteral() Token {
var havePeriod bool
var haveExponent bool
for {
r, ok := s.peek()
if !ok {
return NUMBER
}
if unicode.IsDigit(r) {
s.next()
continue
}
switch r {
case '.':
if havePeriod {
return NUMBER
}
s.next()
havePeriod = true
case 'e', 'E':
if haveExponent {
return NUMBER
}
s.next()
haveExponent = true
if r, ok := s.peek(); ok && (r == '+' || r == '-') {
s.next()
}
default:
return NUMBER
}
}
}
func (s *Scanner) scanStringLiteral() Token {
const delim = '\''
for {
r, ok := s.next()
if !ok {
return eof
}
if r == '\\' {
// Skip escaped character, e.g. 'what\'s up?'
s.next()
continue
}
if r != delim {
continue
}
if r, ok := s.peek(); !ok || r != delim {
return STRING
}
// Two ' characters next to each other
// are collapsed in a string literal,
// rather than escaping the string. We
// don't care about string values, so
// we don't collapse.
s.next()
}
}
func (s *Scanner) scanSimpleComment() Token {
for {
if r, ok := s.next(); !ok || r == '\n' {
return COMMENT
}
}
}
func (s *Scanner) scanBracketedComment() Token {
nesting := 1
for {
r, ok := s.next()
if !ok {
return eof
}
switch r {
case '/':
r, ok := s.peek()
if ok && r == '*' {
s.next()
nesting++
}
case '*':
r, ok := s.peek()
if ok && r == '/' {
s.next()
nesting--
if nesting == 0 {
return COMMENT
}
}
}
}
}
// next returns the next rune if there is one, and advances
// the scanner position, or returns utf8.RuneError if there
// is no valid next rune. The bool result indicates whether
// a valid rune is returned.
func (s *Scanner) next() (rune, bool) {
r, rlen := s.peekLen()
if r != utf8.RuneError {
s.pos += rlen
s.end = s.pos
return r, true
}
return r, false
}
// peek returns the next rune if there is one, or
// utf8.RuneError if not. The bool result indicates
// whether a valid rune is returned.
func (s *Scanner) peek() (rune, bool) {
r, _ := s.peekLen()
if r == utf8.RuneError {
return utf8.RuneError, false
}
return r, true
}
// peekLen returns the next rune (if there is one)
// and its length. If there is no next valid rune,
// utf8.RuneError and a length of -1 are returned.
func (s *Scanner) peekLen() (rune, int) {
if s.pos >= len(s.input) {
return utf8.RuneError, -1
}
return utf8.DecodeRuneInString(s.input[s.pos:])
}