vulndb/sqlutil/b64schema/main.go (109 lines of code) (raw):
// Copyright (c) Facebook, Inc. and its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// b64schema converts a SQL schema file into Go source code that embeds each SQL statement as a base64-encoded string.
package main
import (
	"bufio"
	"bytes"
	"encoding/base64"
	"flag"
	"fmt"
	"go/format"
	"io"
	"os"
	"regexp"
	"strings"
	"text/template"
)
// main parses the command line and generates a Go source file that embeds
// each statement of a SQL schema file as a base64-encoded string, together
// with accessor functions that decode and execute them.
//
// Usage: b64schema [-pkg name] input.sql output.go
//
// Errors are reported on stderr and the process exits with status 1. The
// output file is only created once the complete, formatted source has been
// generated, so a failed run never leaves a truncated file behind.
func main() {
	flag.Usage = func() {
		// flag.PrintDefaults writes to stderr by default; print the usage
		// line there too so the whole message lands on one stream.
		fmt.Fprintf(os.Stderr, "use: %s [flags] input.sql output.go\n", os.Args[0])
		flag.PrintDefaults()
		os.Exit(1)
	}
	pkg := flag.String("pkg", "schema", "set package name")
	flag.Parse()
	if flag.NArg() != 2 {
		flag.Usage()
	}
	if err := generate(*pkg, flag.Arg(0), flag.Arg(1)); err != nil {
		fmt.Fprintf(os.Stderr, "%s: %v\n", os.Args[0], err)
		os.Exit(1)
	}
}

// generate reads the SQL schema at inPath and writes Go source for package
// pkg to outPath. The source is decoderTemplate's output followed by a
// b64<Name> slice holding each statement base64-encoded, where <Name> is
// derived from inPath by gotype.
func generate(pkg, inPath, outPath string) error {
	f, err := os.Open(inPath)
	if err != nil {
		return err
	}
	defer f.Close()

	stmts, err := splitStatements(f)
	if err != nil {
		return fmt.Errorf("reading %s: %v", inPath, err)
	}
	if len(stmts) == 0 {
		return fmt.Errorf("%s: no SQL statements found", inPath)
	}

	var buf bytes.Buffer
	name := gotype(inPath)
	data := struct {
		Pkg  string
		Pub  string
		File string
	}{pkg, name, inPath}
	if err := decoderTemplate.Execute(&buf, data); err != nil {
		return fmt.Errorf("executing template: %v", err)
	}
	fmt.Fprintf(&buf, "// b64%s is auto-generated from %s.\n", name, inPath)
	fmt.Fprintf(&buf, "var b64%s = []string{", name)
	for i, s := range stmts {
		if i > 0 {
			buf.WriteString(", ")
		}
		fmt.Fprintf(&buf, "%q", base64.StdEncoding.EncodeToString([]byte(s)))
	}
	buf.WriteString("}\n")

	// Formatting doubles as a syntax check: an invalid -pkg value or an input
	// name that does not map to a legal Go identifier is reported here
	// instead of producing a file that fails to compile.
	src, err := format.Source(buf.Bytes())
	if err != nil {
		return fmt.Errorf("formatting generated code: %v", err)
	}

	o, err := os.Create(outPath)
	if err != nil {
		return err
	}
	if _, err := o.Write(src); err != nil {
		o.Close()
		return err
	}
	// Close can report deferred write errors, so its result matters.
	return o.Close()
}

// splitStatements reads SQL text from r and splits it into statements. A
// statement ends on a line whose last non-blank character is ';'. Empty
// lines before a statement are dropped; every other line is kept verbatim
// and followed by "\n". This is a line-based heuristic: it does not
// understand SQL quoting or comments. Non-whitespace input left over after
// the last terminator is reported as an error rather than silently dropped.
func splitStatements(r io.Reader) ([]string, error) {
	var (
		stmts []string
		cur   bytes.Buffer
	)
	scanner := bufio.NewScanner(r)
	for scanner.Scan() {
		line := scanner.Text()
		if cur.Len() == 0 && line == "" {
			continue
		}
		cur.WriteString(line)
		cur.WriteByte('\n')
		// Tolerate trailing spaces/tabs after the terminating semicolon.
		if strings.HasSuffix(strings.TrimRight(line, " \t"), ";") {
			stmts = append(stmts, cur.String())
			cur.Reset()
		}
	}
	// Scan stops without complaint on errors such as a line longer than the
	// scanner's buffer; surface them instead of emitting a truncated schema.
	if err := scanner.Err(); err != nil {
		return nil, err
	}
	if strings.TrimSpace(cur.String()) != "" {
		return nil, fmt.Errorf("unterminated statement at end of input (missing ';')")
	}
	return stmts, nil
}
// decoderTemplate renders the first part of the generated file: the
// "Code generated ... DO NOT EDIT." marker, the package clause, its imports,
// and the Init<Pub> and <Pub> functions. Both functions read a b64<Pub>
// slice that main writes immediately after executing this template.
//
// Template fields: Pkg is the package name, Pub is the exported identifier
// derived from the input file name, and File is the input path quoted in the
// generated doc comments.
var decoderTemplate = template.Must(template.New("decoder").Parse(`// Code generated by b64schema. DO NOT EDIT.

package {{.Pkg}}

import (
	"context"
	"database/sql"
	"encoding/base64"
)

// Init{{.Pub}} is auto-generated. Executes each SQL statement from {{.File}}.
func Init{{.Pub}}(ctx context.Context, db *sql.DB) error {
	for _, stmt := range {{.Pub}}() {
		_, err := db.ExecContext(ctx, stmt)
		if err != nil {
			return err
		}
	}
	return nil
}

// {{.Pub}} is auto-generated. Returns each SQL statement from {{.File}}.
func {{.Pub}}() []string {
	s := make([]string, len(b64{{.Pub}}))
	for i := 0; i < len(s); i++ {
		// The encoded strings are produced by b64schema itself, so decoding
		// is not expected to fail.
		v, _ := base64.StdEncoding.DecodeString(b64{{.Pub}}[i])
		s[i] = string(v)
	}
	return s
}

`))
// gotyper matches the runs of ASCII letters and digits that gotype joins into
// an identifier; every other character acts as a word separator.
var gotyper = regexp.MustCompile("[0-9A-Za-z]+")

// gotype derives an exported Go identifier from name (main passes the input
// file path) by splitting it into alphanumeric words and concatenating them.
// A word whose upper-case form is a key of commonInitialisms is written fully
// upper-case (e.g. "sql" becomes "SQL" if "SQL" is listed); every other word
// only has its first character upper-cased. The result is not validated: if
// the first word starts with a digit, it is not a legal Go identifier.
func gotype(name string) string {
	var b strings.Builder
	for _, word := range gotyper.FindAllString(name, -1) {
		upper := strings.ToUpper(word)
		if _, ok := commonInitialisms[upper]; ok {
			b.WriteString(upper)
			continue
		}
		// word is non-empty ASCII (guaranteed by gotyper), so upper-casing its
		// first byte matches what the deprecated strings.Title produced.
		b.WriteString(strings.ToUpper(word[:1]))
		b.WriteString(word[1:])
	}
	return b.String()
}