vulndb/sqlutil/record.go (135 lines of code) (raw):
// Copyright (c) Facebook, Inc. and its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlutil
import (
	"fmt"
	"reflect"
	"unicode"
	"unicode/utf8"
)
// Record represents a database record: a set of field (column) names
// together with the values that belong to them.
// See RecordType for mapping structs to Records.
type Record interface {
Subset(...string) Record // Subset returns a Record restricted to the named fields.
Fields() []string // Fields returns the field names to be used as table columns.
Values() []interface{} // Values returns the field values, e.g. to be used as Scan arguments.
}
// Records is a slice of Record values that itself implements the Record
// interface, so a batch of records can be used wherever a single Record is
// accepted.
type Records []Record
// NewRecords builds Records from a slice of structs, wrapping every element
// of T in a RecordType. The elements are visited in slice order.
func NewRecords(T interface{}) Records {
	var out Records
	collect := func(elem reflect.Value) {
		out = append(out, NewRecordType(elem.Interface()))
	}
	walkSlice(T, collect)
	return out
}
// Subset applies Subset(fields...) to every Record in r and returns the
// results as a new Records of the same length.
//
// The result is returned through the Record interface; its dynamic type is
// Records, so callers that need the slice itself can type assert it:
// recs.Subset("a", "b").(Records).
func (r Records) Subset(fields ...string) Record {
	out := make(Records, len(r))
	for i, rec := range r {
		out[i] = rec.Subset(fields...)
	}
	return out
}
// Fields returns the table column names reported by the first Record,
// or nil when r is empty.
func (r Records) Fields() []string {
	if len(r) > 0 {
		return r[0].Fields()
	}
	return nil
}
// Values returns the column values of every Record in r, concatenated in
// record order. It returns nil when r is empty.
func (r Records) Values() []interface{} {
	if len(r) == 0 {
		return nil
	}
	// Size the result up front using the first record's field count.
	perRecord := len(r[0].Fields())
	all := make([]interface{}, 0, len(r)*perRecord)
	for i := range r {
		all = append(all, r[i].Values()...)
	}
	return all
}
// RecordType implements the Record interface for any struct with 'sql' tags as T.
// Exported struct fields become record fields; a field without an 'sql' tag
// is named after the struct field itself.
type RecordType struct {
// T is the wrapped struct, or a pointer to it. Fields and Values panic
// (via walkStruct) when T is anything else.
T interface{}
}
// NewRecordType wraps T in a RecordType. T is stored as given and is not
// inspected until Fields, Values or Subset is called.
func NewRecordType(T interface{}) RecordType {
	return RecordType{T: T}
}
// Subset returns a Record holding only those fields of r.T whose column
// name appears in fields.
func (r RecordType) Subset(fields ...string) Record {
	cols, vals := r.Fields(), r.Values()
	return newRecordSubset(cols, vals, fields)
}
// Fields returns the table column names derived from the exported struct
// fields of r.T, in declaration order.
// A field's 'sql' tag is used as its column name when set; otherwise the
// struct field name is used.
func (r RecordType) Fields() []string {
	const tag = "sql"
	var names []string
	walkStruct(r.T, func(ft reflect.StructField, _ reflect.Value) {
		if name := ft.Tag.Get(tag); name != "" {
			names = append(names, name)
			return
		}
		names = append(names, ft.Name)
	})
	return names
}
// Values returns one entry per exported struct field of r.T, in declaration
// order. An entry is a pointer to the field when the field is addressable
// (as is the case when T is a pointer to a struct), which makes it usable as
// a Scan destination; otherwise it is the field's value, which is usable as
// an argument to Exec (e.g. INSERT, REPLACE) calls.
func (r RecordType) Values() []interface{} {
	var out []interface{}
	walkStruct(r.T, func(_ reflect.StructField, fv reflect.Value) {
		item := fv
		if fv.CanAddr() {
			item = fv.Addr()
		}
		out = append(out, item.Interface())
	})
	return out
}
// recordSubset is a Record made of an explicit list of column names and
// their values. cols and vals are parallel: vals[i] belongs to cols[i].
type recordSubset struct {
cols []string
vals []interface{}
}
// newRecordSubset returns a recordSubset containing the (column, value)
// pairs from cols/vals whose column name is listed in fields. The order of
// cols is preserved; names in fields that match no column are ignored.
// It panics if cols and vals have different lengths.
func newRecordSubset(cols []string, vals []interface{}, fields []string) *recordSubset {
	if len(cols) != len(vals) {
		panic("invalid cols/vals have different length")
	}
	// Collect the requested names into a set for membership checks.
	wanted := make(map[string]struct{}, len(fields))
	for _, name := range fields {
		wanted[name] = struct{}{}
	}
	subset := new(recordSubset)
	for idx := range cols {
		if _, ok := wanted[cols[idx]]; !ok {
			continue
		}
		subset.cols = append(subset.cols, cols[idx])
		subset.vals = append(subset.vals, vals[idx])
	}
	return subset
}
// Subset narrows rs further, keeping only the columns named in fields.
func (rs *recordSubset) Subset(fields ...string) Record {
	cols, vals := rs.cols, rs.vals
	return newRecordSubset(cols, vals, fields)
}
// Fields returns the column names held by rs.
func (rs *recordSubset) Fields() []string {
	return rs.cols
}
// Values returns the column values held by rs, parallel to Fields.
func (rs *recordSubset) Values() []interface{} {
	return rs.vals
}
// walkStruct uses reflection to visit the fields of s, which must be a
// struct or a pointer to one, and calls fn with the field descriptor and
// field value of every field accepted by exportedField, in declaration order.
// It panics if s is not a struct or a pointer to a struct.
func walkStruct(s interface{}, fn func(reflect.StructField, reflect.Value)) {
	val := reflect.ValueOf(s)
	if val.Kind() == reflect.Ptr {
		val = val.Elem()
	}
	if val.Kind() != reflect.Struct {
		panic(fmt.Sprintf("walkStruct: s is not struct: %T", s))
	}
	typ := val.Type()
	for i, n := 0, val.NumField(); i < n; i++ {
		field := typ.Field(i)
		if exportedField(field) {
			fn(field, val.Field(i))
		}
	}
}
// exportedField reports whether f should be visited by walkStruct: it must
// not be an embedded (anonymous) field and its name must start with an
// upper-case letter.
//
// The first rune of the name is decoded as UTF-8 rather than taken from the
// first byte, so that a lower-case non-ASCII name such as "élan" is not
// mistaken for an exported one because of the value of its lead byte.
func exportedField(f reflect.StructField) bool {
	if f.Anonymous {
		return false
	}
	first, _ := utf8.DecodeRuneInString(f.Name)
	return unicode.IsUpper(first)
}
// walkSlice walks s using reflection and calls fn for each element, in
// index order. s must be a slice or, like walkStruct's argument, a pointer
// to one.
//
// It panics with a descriptive message if s is anything else, including
// nil: the kind is read from the reflect.Value, which is safe for a nil
// interface, whereas reflect.TypeOf(nil) yields a nil Type.
func walkSlice(s interface{}, fn func(reflect.Value)) {
	v := reflect.ValueOf(s)
	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	if v.Kind() != reflect.Slice {
		panic(fmt.Sprintf("walkSlice: s is not slice: %T", s))
	}
	for i := 0; i < v.Len(); i++ {
		fn(v.Index(i))
	}
}