ecs-agent/gogenerate/awssdk.go (162 lines of code) (raw):
//go:build codegen
// +build codegen
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.
// Note: This package uses the 'codegen' directive for compatibility with
// the AWS SDK's code generation scripts.
package main
import (
"bytes"
"flag"
"fmt"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"strings"
"text/template"
"github.com/aws/aws-sdk-go/private/model/api"
"github.com/aws/aws-sdk-go/private/util"
"golang.org/x/tools/imports"
)
func main() {
os.Exit(_main())
}
func ShapeListWithErrors(a *api.API) []*api.Shape {
list := make([]*api.Shape, 0, len(a.Shapes))
for _, n := range a.ShapeNames() {
list = append(list, a.Shapes[n])
}
return list
}
var typesOnlyTplAPI = template.Must(template.New("api").Funcs(template.FuncMap{
"ShapeListWithErrors": ShapeListWithErrors,
}).Parse(`
{{ $shapeList := ShapeListWithErrors $ }}
{{ range $_, $s := $shapeList }}
{{ if eq $s.Type "structure"}}{{ $s.GoCode }}{{ end }}
{{ end }}
`))
var (
typesOnly bool
copyrightHeader string
)
func init() {
types := flag.Bool("typesOnly", false, "only generate types")
copyrightFile := flag.String("copyright_file", "", "Copyright file used to add copyright header")
flag.Parse()
if *copyrightFile == "" {
fmt.Println("copyright_file must be set")
os.Exit(1)
}
copyrightText, err := getCopyrightFile(*copyrightFile)
if err != nil {
fmt.Println("error reading copyright_file: ", err)
os.Exit(1)
}
copyrightText += "\nCode generated by [agent/gogenerate/awssdk.go] DO NOT EDIT."
b := &strings.Builder{}
for i, line := range strings.Split(copyrightText, "\n") {
if i != 0 {
b.WriteString("\n")
}
b.WriteString("//")
if line != "" {
b.WriteString(" ")
b.WriteString(line)
}
}
b.WriteString("\n")
copyrightHeader = b.String()
typesOnly = *types
}
// return-based exit code so the 'defer' works
func _main() int {
apiFile := "./api/api-2.json"
var err error
if typesOnly {
err = genTypesOnlyAPI(apiFile)
} else {
err = genFull(apiFile)
}
if err != nil {
fmt.Println(err)
return 1
}
return 0
}
func genTypesOnlyAPI(file string) error {
apiGen := &api.API{
NoRemoveUnusedShapes: true,
NoRenameToplevelShapes: true,
NoGenStructFieldAccessors: true,
}
apiGen.Attach(file)
apiGen.Setup()
// to reset imports so that timestamp has an entry in the map.
apiGen.APIGoCode()
var buf bytes.Buffer
err := typesOnlyTplAPI.Execute(&buf, apiGen)
if err != nil {
panic(err)
}
code := strings.TrimSpace(buf.String())
code = util.GoFmt(code)
// Ignore dir error, filepath will catch it for an invalid path.
os.Mkdir(apiGen.PackageName(), 0755)
// Fix imports.
codeWithImports, err := imports.Process("", []byte(fmt.Sprintf("package %s\n\n%s", apiGen.PackageName(), code)), nil)
if err != nil {
fmt.Println(err)
return err
}
outFile := filepath.Join(apiGen.PackageName(), "api.go")
err = ioutil.WriteFile(outFile, []byte(fmt.Sprintf("%s\n%s", copyrightHeader, codeWithImports)), 0644)
if err != nil {
return err
}
return nil
}
func genFull(file string) error {
// Heavily inspired by https://github.com/aws/aws-sdk-go/blob/45ba2f817fbf625fe48cf9b51fae13e5dfba6171/internal/model/cli/gen-api/main.go#L36
// That code is copyright Amazon
api := &api.API{}
api.Attach(file)
paginatorsFile := strings.Replace(file, "api-2.json", "paginators-1.json", -1)
if _, err := os.Stat(paginatorsFile); err == nil {
api.AttachPaginators(paginatorsFile)
}
docsFile := strings.Replace(file, "api-2.json", "docs-2.json", -1)
if _, err := os.Stat(docsFile); err == nil {
api.AttachDocs(docsFile)
}
api.Setup()
if err := genFile(api, api.APIGoCode(), "api.go"); err != nil {
return err
}
if err := genFile(api, api.ServiceGoCode(), "service.go"); err != nil {
return err
}
if err := genFile(api, api.APIErrorsGoCode(), "errors.go"); err != nil {
return err
}
return nil
}
func genFile(api *api.API, code string, fileName string) error {
processedCode, err := imports.Process("",
[]byte(fmt.Sprintf("package %s\n\n%s", api.PackageName(), code)),
nil)
if err != nil {
fmt.Println("Error processing file imports: ", err)
return err
}
// Ignore dir error, filepath will catch it for an invalid path.
os.Mkdir(api.PackageName(), 0755)
outFile, err := os.Create(filepath.Join(api.PackageName(), fileName))
if err != nil {
fmt.Println("Error creating file: ", err)
return err
}
defer outFile.Close()
withHeader := fmt.Sprintf("%s\n%s", copyrightHeader, processedCode)
goimports := exec.Command("goimports")
goimports.Stdin = bytes.NewBufferString(withHeader)
goimports.Stdout = outFile
err = goimports.Run()
if err != nil {
fmt.Println("Error running goimports: ", err)
return err
}
return nil
}
func getCopyrightFile(filename string) (string, error) {
contents, err := ioutil.ReadFile(filename)
return string(contents), err
}