json/visitor.go (359 lines of code) (raw):
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package json
import (
"fmt"
"io"
"math"
"strconv"
"unicode/utf8"
structform "github.com/elastic/go-structform"
)
// Visitor implements the structform.Visitor interface, JSON encoding the
// structure being visited.
type Visitor struct {
	// w receives every encoded byte.
	w writer
	// escapeSet is indexed by ASCII byte value; true marks bytes that
	// OnString must escape.
	escapeSet []bool
	// first.current is true while no key/element has been written yet in
	// the innermost open object or array.
	first boolStack
	// inArray.current is true while the innermost open container is an array.
	inArray boolStack
	// scratch is a reusable buffer for small writes (single bytes, escape
	// sequences, formatted numbers).
	scratch [64]byte
	// ignoreInvalidFloat makes NaN/Inf encode as null instead of failing.
	ignoreInvalidFloat bool
	// explicitRadixPoint forces floats to always carry a radix point.
	explicitRadixPoint bool
}
// boolStack is a stack of booleans whose top element lives in current while
// the elements below it are kept in stack.
type boolStack struct {
	stack   []bool // values saved by push and restored by pop
	current bool   // top of the stack
}
// Compile-time check that *Visitor satisfies structform.Visitor.
var _ structform.Visitor = &Visitor{}

// htmlEscapeSet marks the ASCII bytes that get escaped when HTML escaping is
// enabled. It is populated by init.
var htmlEscapeSet = [utf8.RuneSelf]bool{}

// jsonEscapeSet marks the ASCII bytes that get escaped when HTML escaping is
// disabled. It is populated by init.
var jsonEscapeSet = [utf8.RuneSelf]bool{}
// writer wraps the io.Writer that receives the encoded output.
type writer struct {
	out io.Writer
}
func init() {
// control characters must be escaped
for i := 0; i < 32; i++ {
htmlEscapeSet[i] = true
jsonEscapeSet[i] = true
}
// json string required escaping
for _, c := range "\"\\" {
htmlEscapeSet[c] = true
jsonEscapeSet[c] = true
}
// html escaping
for _, c := range "&<>" {
htmlEscapeSet[c] = true
}
}
// write passes b to the wrapped io.Writer and reports only the error; the
// number of bytes written is discarded.
func (w writer) write(b []byte) error {
	if _, err := w.out.Write(b); err != nil {
		return err
	}
	return nil
}
// NewVisitor creates a Visitor that writes its JSON output to out. The new
// visitor starts with the HTML escape set selected; use SetEscapeHTML to
// change that.
func NewVisitor(out io.Writer) *Visitor {
	return &Visitor{
		w:         writer{out},
		escapeSet: htmlEscapeSet[:],
	}
}
// SetEscapeHTML selects the escape table used for strings: the HTML table
// (which also escapes '&', '<' and '>') when b is true, the plain JSON table
// otherwise.
func (vs *Visitor) SetEscapeHTML(b bool) {
	selected := jsonEscapeSet[:]
	if b {
		selected = htmlEscapeSet[:]
	}
	vs.escapeSet = selected
}
// SetIgnoreInvalidFloat controls the treatment of floating point values that
// have no JSON representation (NaN and the infinities).
// With false, which is the default, encoding such a value returns an error.
// With true, the value is replaced with the `null` symbol.
func (vs *Visitor) SetIgnoreInvalidFloat(b bool) {
	vs.ignoreInvalidFloat = b
}
// SetExplicitRadixPoint controls whether floating point values are always
// encoded with a radix point.
// With false, which is the default, the radix point is omitted when it is not
// needed, so 1.0 is written as 1.
// With true, ".0" is inserted whenever the encoded number (or, in scientific
// notation, its mantissa) has no radix point, so 1.0 is written as 1.0.
// This may be useful to signal the type of the number to a JSON parser.
func (vs *Visitor) SetExplicitRadixPoint(b bool) {
	vs.explicitRadixPoint = b
}
// writeByte writes the single byte b, staging it in the first slot of the
// scratch buffer.
func (vs *Visitor) writeByte(b byte) error {
	single := vs.scratch[:1]
	single[0] = b
	return vs.w.write(single)
}
// writeString writes s verbatim, converting it to a byte slice with
// str2Bytes first.
func (vs *Visitor) writeString(s string) error {
	raw := str2Bytes(s)
	return vs.w.write(raw)
}
// OnObjectStart opens a JSON object: it emits the array element separator if
// one is due, records a new nesting level (no field written yet, not an
// array) and writes '{'. Both arguments are ignored.
func (vs *Visitor) OnObjectStart(_ int, _ structform.BaseType) error {
	err := vs.tryElemNext()
	if err != nil {
		return err
	}

	vs.first.push(true)
	vs.inArray.push(false)

	const opening = '{'
	return vs.writeByte(opening)
}
// OnObjectFinished closes the current JSON object: it drops the object's
// nesting level from both state stacks and writes '}'.
func (vs *Visitor) OnObjectFinished() error {
	const closing = '}'

	vs.first.pop()
	vs.inArray.pop()
	return vs.writeByte(closing)
}
// OnKeyRef writes an object key given as a byte slice: the field separator
// (when due), the key as a JSON string, and the trailing ':'. The first error
// encountered is returned.
func (vs *Visitor) OnKeyRef(s []byte) error {
	if err := vs.onFieldNext(); err != nil {
		return err
	}
	if err := vs.OnStringRef(s); err != nil {
		return err
	}
	return vs.writeByte(':')
}
// OnKey writes an object key: the field separator (when due), the key as a
// JSON string, and the trailing ':'. The first error encountered is returned.
func (vs *Visitor) OnKey(s string) error {
	if err := vs.onFieldNext(); err != nil {
		return err
	}
	if err := vs.OnString(s); err != nil {
		return err
	}
	return vs.writeByte(':')
}
// onFieldNext prepares the output for the next object field. The first field
// of an object only clears the "first" flag; every later field is preceded
// by a ','.
func (vs *Visitor) onFieldNext() error {
	if !vs.first.current {
		return vs.writeByte(',')
	}
	vs.first.current = false
	return nil
}
// OnArrayStart opens a JSON array: it emits the array element separator if
// one is due, records a new nesting level (no element written yet, inside an
// array) and writes '['. Both arguments are ignored.
func (vs *Visitor) OnArrayStart(_ int, _ structform.BaseType) error {
	err := vs.tryElemNext()
	if err != nil {
		return err
	}

	vs.first.push(true)
	vs.inArray.push(true)

	const opening = '['
	return vs.writeByte(opening)
}
// OnArrayFinished closes the current JSON array: it drops the array's
// nesting level from both state stacks and writes ']'.
func (vs *Visitor) OnArrayFinished() error {
	const closing = ']'

	vs.first.pop()
	vs.inArray.pop()
	return vs.writeByte(closing)
}
// tryElemNext emits the separator that precedes an array element. It does
// nothing outside of arrays, only clears the "first" flag for the first
// element of an array, and writes commaSymbol before every later element.
func (vs *Visitor) tryElemNext() error {
	switch {
	case !vs.inArray.current:
		return nil
	case vs.first.current:
		vs.first.current = false
		return nil
	default:
		return vs.w.write(commaSymbol)
	}
}
// hex holds the lowercase hexadecimal digits, indexed by nibble value; it is
// used when writing \uXXXX escape sequences.
var hex = "0123456789abcdef"
// OnStringRef writes the bytes in s as a JSON string by handing them to
// OnString after a bytes2Str conversion.
func (vs *Visitor) OnStringRef(s []byte) error {
	str := bytes2Str(s)
	return vs.OnString(str)
}
// OnString writes s as a double-quoted JSON string.
//
// ASCII bytes flagged in vs.escapeSet are escaped: '\\' and '"' with a
// leading backslash, newline, carriage return and tab as \n, \r and \t, and
// every other flagged byte as \u00XX. Bytes that are not valid UTF-8 are
// replaced with invalidCharSym, and U+2028/U+2029 are always written as
// \u2028/\u2029. Everything else is copied through unchanged.
//
// If the string is an array element, the element separator is emitted first.
// The first error reported while writing aborts the encoding and is returned
// to the caller.
func (vs *Visitor) OnString(s string) error {
	if err := vs.tryElemNext(); err != nil {
		return err
	}

	escapeSet := vs.escapeSet
	if err := vs.writeByte('"'); err != nil {
		return err
	}

	// start marks the beginning of the pending run of bytes that need no
	// escaping; the run is flushed right before each escape sequence.
	start := 0
	for i := 0; i < len(s); {
		if b := s[i]; b < utf8.RuneSelf {
			if !escapeSet[b] {
				i++
				continue
			}
			if start < i {
				if err := vs.writeString(s[start:i]); err != nil {
					return err
				}
			}

			// Build the escape sequence in scratch; n is its length.
			n := 2
			vs.scratch[0] = '\\'
			switch b {
			case '\\', '"':
				vs.scratch[1] = b
			case '\n':
				vs.scratch[1] = 'n'
			case '\r':
				vs.scratch[1] = 'r'
			case '\t':
				vs.scratch[1] = 't'
			default:
				// This encodes bytes < 0x20 except for \n, \r and \t,
				// as well as <, > and & when they are part of the escape
				// set. The latter are escaped because they can lead to
				// security holes when user-controlled strings are rendered
				// into JSON and served to some browsers.
				vs.scratch[1], vs.scratch[2], vs.scratch[3] = 'u', '0', '0'
				vs.scratch[4] = hex[b>>4]
				vs.scratch[5] = hex[b&0xF]
				n = 6
			}
			if err := vs.w.write(vs.scratch[:n]); err != nil {
				return err
			}
			i++
			start = i
			continue
		}

		c, size := utf8.DecodeRuneInString(s[i:])
		if c == utf8.RuneError && size == 1 {
			// invalid UTF-8 byte: substitute the replacement symbol
			if start < i {
				if err := vs.writeString(s[start:i]); err != nil {
					return err
				}
			}
			if err := vs.w.write(invalidCharSym); err != nil {
				return err
			}
			i += size
			start = i
			continue
		}

		// U+2028 is LINE SEPARATOR.
		// U+2029 is PARAGRAPH SEPARATOR.
		// They are both technically valid characters in JSON strings,
		// but don't work in JSONP, which has to be evaluated as JavaScript,
		// and can lead to security holes there. It is valid JSON to
		// escape them, so we do so unconditionally.
		// See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion.
		if c == '\u2028' || c == '\u2029' {
			if start < i {
				if err := vs.writeString(s[start:i]); err != nil {
					return err
				}
			}
			n := copy(vs.scratch[:], `\u202`)
			vs.scratch[n] = hex[c&0xF]
			if err := vs.w.write(vs.scratch[:n+1]); err != nil {
				return err
			}
			i += size
			start = i
			continue
		}

		i += size
	}

	// flush the trailing run of unescaped bytes
	if start < len(s) {
		if err := vs.writeString(s[start:]); err != nil {
			return err
		}
	}
	return vs.writeByte('"')
}
// OnBool writes b as trueSymbol or falseSymbol, preceded by the array
// element separator when one is due.
func (vs *Visitor) OnBool(b bool) error {
	if err := vs.tryElemNext(); err != nil {
		return err
	}

	symbol := falseSymbol
	if b {
		symbol = trueSymbol
	}
	return vs.w.write(symbol)
}
// OnNil writes nullSymbol, preceded by the array element separator when one
// is due.
func (vs *Visitor) OnNil() error {
	if err := vs.tryElemNext(); err != nil {
		return err
	}
	return vs.w.write(nullSymbol)
}
// OnInt8 encodes i as a JSON number by widening it to int64 for onInt.
func (vs *Visitor) OnInt8(i int8) error {
	widened := int64(i)
	return vs.onInt(widened)
}
// OnInt16 encodes i as a JSON number by widening it to int64 for onInt.
func (vs *Visitor) OnInt16(i int16) error {
	widened := int64(i)
	return vs.onInt(widened)
}
// OnInt32 encodes i as a JSON number by widening it to int64 for onInt.
func (vs *Visitor) OnInt32(i int32) error {
	widened := int64(i)
	return vs.onInt(widened)
}
// OnInt64 encodes i as a JSON number by delegating to onInt.
func (vs *Visitor) OnInt64(i int64) error {
	err := vs.onInt(i)
	return err
}
// OnInt encodes i as a JSON number by widening it to int64 for onInt.
func (vs *Visitor) OnInt(i int) error {
	widened := int64(i)
	return vs.onInt(widened)
}
// onInt writes the signed integer v as a JSON number, preceded by the array
// element separator when one is due. The value is handed to onNumber as a
// sign flag plus its raw 64-bit pattern. Any error from the separator or
// from writing the digits is returned to the caller.
func (vs *Visitor) onInt(v int64) error {
	if err := vs.tryElemNext(); err != nil {
		return err
	}
	return vs.onNumber(v < 0, uint64(v))
}
// OnUint8 encodes u as a JSON number by widening it to uint64 for onUint.
func (vs *Visitor) OnUint8(u uint8) error {
	widened := uint64(u)
	return vs.onUint(widened)
}
// OnByte encodes b as a JSON number by widening it to uint64 for onUint.
func (vs *Visitor) OnByte(b byte) error {
	widened := uint64(b)
	return vs.onUint(widened)
}
// OnUint16 encodes u as a JSON number by widening it to uint64 for onUint.
func (vs *Visitor) OnUint16(u uint16) error {
	widened := uint64(u)
	return vs.onUint(widened)
}
// OnUint32 encodes u as a JSON number by widening it to uint64 for onUint.
func (vs *Visitor) OnUint32(u uint32) error {
	widened := uint64(u)
	return vs.onUint(widened)
}
// OnUint64 encodes u as a JSON number by delegating to onUint.
func (vs *Visitor) OnUint64(u uint64) error {
	err := vs.onUint(u)
	return err
}
// OnUint encodes u as a JSON number by widening it to uint64 for onUint.
func (vs *Visitor) OnUint(u uint) error {
	widened := uint64(u)
	return vs.onUint(widened)
}
// onUint writes the unsigned integer u as a JSON number, preceded by the
// array element separator when one is due. The first error encountered is
// returned.
func (vs *Visitor) onUint(u uint64) error {
	err := vs.tryElemNext()
	if err == nil {
		err = vs.onNumber(false, u)
	}
	return err
}
// onNumber writes the decimal representation of an integer to the output.
// u carries the value's bits and neg reports whether a leading '-' has to be
// written; when neg is set, u is negated first so that the digit loops work
// on the magnitude. Digits are produced from right to left at the end of
// vs.scratch, and the filled tail of the buffer is handed to the writer in a
// single call whose error is returned.
func (vs *Visitor) onNumber(neg bool, u uint64) error {
	if neg {
		// unsigned negation turns the two's complement bit pattern into
		// the magnitude of the negative value
		u = -u
	}
	i := len(vs.scratch)
	// common case: use constants for / because
	// the compiler can optimize it into a multiply+shift
	if ^uintptr(0)>>32 == 0 {
		// Only true when uintptr is 32 bits wide: peel off 9 decimal
		// digits at a time until the remainder fits into a uintptr.
		for u > uint64(^uintptr(0)) {
			q := u / 1e9
			us := uintptr(u - q*1e9) // u % 1e9 fits into a uintptr
			for j := 9; j > 0; j-- {
				i--
				qs := us / 10
				vs.scratch[i] = byte(us - qs*10 + '0')
				us = qs
			}
			u = q
		}
	}
	// u guaranteed to fit into a uintptr
	us := uintptr(u)
	for us >= 10 {
		i--
		q := us / 10
		vs.scratch[i] = byte(us - q*10 + '0')
		us = q
	}
	// us < 10: write the most significant digit
	i--
	vs.scratch[i] = byte(us + '0')
	if neg {
		i--
		vs.scratch[i] = '-'
	}
	return vs.w.write(vs.scratch[i:])
}
// OnFloat32 encodes f as a JSON number, formatting it with 32-bit precision.
func (vs *Visitor) OnFloat32(f float32) error {
	const bitSize = 32
	return vs.onFloat(float64(f), bitSize)
}
// OnFloat64 encodes f as a JSON number, formatting it with 64-bit precision.
func (vs *Visitor) OnFloat64(f float64) error {
	const bitSize = 64
	return vs.onFloat(f, bitSize)
}
// onFloat writes f as a JSON number, preceded by the array element separator
// when one is due. bits is the bit size (32 or 64) passed on to
// strconv.AppendFloat.
//
// NaN and the infinities are written as nullSymbol when ignoreInvalidFloat is
// set and cause an error otherwise. When explicitRadixPoint is set, ".0" is
// inserted after the mantissa of any encoding that has no radix point.
func (vs *Visitor) onFloat(f float64, bits int) error {
	if err := vs.tryElemNext(); err != nil {
		return err
	}

	if math.IsNaN(f) || math.IsInf(f, 0) {
		if vs.ignoreInvalidFloat {
			return vs.w.write(nullSymbol)
		}
		return fmt.Errorf("unsupported float value: %v", f)
	}

	encoded := strconv.AppendFloat(vs.scratch[:0], f, 'g', -1, bits)
	if !vs.explicitRadixPoint {
		return vs.w.write(encoded)
	}

	// encoded is either in decimal form or in scientific notation.
	// Decimal form without a '.' gets ".0" appended: 1 becomes 1.0.
	// Scientific notation without a '.' in the mantissa gets ".0" inserted
	// in front of the exponent: 1e+2 becomes 1.0e+2.
	mantissaEnd, hasPoint := len(encoded), false
	for pos, ch := range encoded {
		if ch == '.' { // decimal point
			hasPoint = true
			break
		}
		if ch == 'e' { // exponent separator
			mantissaEnd = pos
			break
		}
	}

	if err := vs.w.write(encoded[:mantissaEnd]); err != nil {
		return err
	}
	if !hasPoint {
		if err := vs.w.write([]byte(".0")); err != nil {
			return err
		}
	}
	return vs.w.write(encoded[mantissaEnd:])
}
// init replaces the saved-values slice with an empty one that has room for
// 32 entries.
func (s *boolStack) init() {
	const initialCap = 32
	s.stack = make([]bool, 0, initialCap)
}
// push makes b the new top of the stack; the previous top is saved so that
// pop can restore it.
func (s *boolStack) push(b bool) {
	previous := s.current
	s.current = b
	s.stack = append(s.stack, previous)
}
// pop discards the current top of the stack and restores the value saved by
// the matching push. It panics when there is no saved value left.
func (s *boolStack) pop() {
	depth := len(s.stack)
	if depth == 0 {
		panic("pop from empty stack")
	}
	top := depth - 1
	s.current, s.stack = s.stack[top], s.stack[:top]
}