private/model/api/smoke.go (176 lines of code) (raw):
//go:build codegen
// +build codegen
package api
import (
"bytes"
"encoding/json"
"fmt"
"os"
"sort"
"text/template"
)
// SmokeTestSuite defines the test suite for smoke tests.
type SmokeTestSuite struct {
Version int `json:"version"`
DefaultRegion string `json:"defaultRegion"`
TestCases []SmokeTestCase `json:"testCases"`
}
// SmokeTestCase provides the definition for a integration smoke test case.
type SmokeTestCase struct {
OpName string `json:"operationName"`
Input map[string]interface{} `json:"input"`
ExpectErr bool `json:"errorExpectedFromService"`
}
var smokeTestsCustomizations = map[string]func(*SmokeTestSuite) error{
"sts": stsSmokeTestCustomization,
"waf": wafSmokeTestCustomization,
"wafregional": wafRegionalSmokeTestCustomization,
"iotdataplane": iotDataPlaneSmokeTestCustomization,
"opsworks": removeSmokeTests,
"cloudsearch": removeSmokeTests,
}
func iotDataPlaneSmokeTestCustomization(suite *SmokeTestSuite) error {
suite.TestCases = []SmokeTestCase{}
return nil
}
func removeSmokeTests(suite *SmokeTestSuite) error {
suite.TestCases = []SmokeTestCase{}
return nil
}
func wafSmokeTestCustomization(suite *SmokeTestSuite) error {
return filterWAFCreateSqlInjectionMatchSet(suite)
}
func wafRegionalSmokeTestCustomization(suite *SmokeTestSuite) error {
return filterWAFCreateSqlInjectionMatchSet(suite)
}
func filterWAFCreateSqlInjectionMatchSet(suite *SmokeTestSuite) error {
const createSqlInjectionMatchSetOp = "CreateSqlInjectionMatchSet"
var testCases []SmokeTestCase
for _, testCase := range suite.TestCases {
if testCase.OpName == createSqlInjectionMatchSetOp {
continue
}
testCases = append(testCases, testCase)
}
suite.TestCases = testCases
return nil
}
func stsSmokeTestCustomization(suite *SmokeTestSuite) error {
const getSessionTokenOp = "GetSessionToken"
const getCallerIdentityOp = "GetCallerIdentity"
opTestMap := make(map[string][]SmokeTestCase)
for _, testCase := range suite.TestCases {
opTestMap[testCase.OpName] = append(opTestMap[testCase.OpName], testCase)
}
if _, ok := opTestMap[getSessionTokenOp]; ok {
delete(opTestMap, getSessionTokenOp)
}
if _, ok := opTestMap[getCallerIdentityOp]; !ok {
opTestMap[getCallerIdentityOp] = append(opTestMap[getCallerIdentityOp], SmokeTestCase{
OpName: getCallerIdentityOp,
Input: map[string]interface{}{},
ExpectErr: false,
})
}
var testCases []SmokeTestCase
var keys []string
for name := range opTestMap {
keys = append(keys, name)
}
sort.Strings(keys)
for _, name := range keys {
testCases = append(testCases, opTestMap[name]...)
}
suite.TestCases = testCases
return nil
}
// BuildInputShape returns the Go code as a string for initializing the test
// case's input shape.
func (c SmokeTestCase) BuildInputShape(ref *ShapeRef) string {
b := NewShapeValueBuilder()
return fmt.Sprintf("&%s{\n%s\n}",
b.GoType(ref, true),
b.BuildShape(ref, c.Input, false),
)
}
// AttachSmokeTests attaches the smoke test cases to the API model.
func (a *API) AttachSmokeTests(filename string) error {
f, err := os.Open(filename)
if err != nil {
return fmt.Errorf("failed to open smoke tests %s, err: %v", filename, err)
}
defer f.Close()
if err := json.NewDecoder(f).Decode(&a.SmokeTests); err != nil {
return fmt.Errorf("failed to decode smoke tests %s, err: %v", filename, err)
}
if v := a.SmokeTests.Version; v != 1 {
return fmt.Errorf("invalid smoke test version, %d", v)
}
if fn, ok := smokeTestsCustomizations[a.PackageName()]; ok {
if err := fn(&a.SmokeTests); err != nil {
return err
}
}
return nil
}
// APISmokeTestsGoCode returns the Go Code string for the smoke tests.
func (a *API) APISmokeTestsGoCode() string {
w := bytes.NewBuffer(nil)
a.resetImports()
a.AddImport("context")
a.AddImport("testing")
a.AddImport("time")
a.AddSDKImport("aws")
a.AddSDKImport("aws/request")
a.AddSDKImport("aws/awserr")
a.AddSDKImport("aws/request")
a.AddSDKImport("awstesting/integration")
a.AddImport(a.ImportPath())
smokeTests := struct {
API *API
SmokeTestSuite
}{
API: a,
SmokeTestSuite: a.SmokeTests,
}
if err := smokeTestTmpl.Execute(w, smokeTests); err != nil {
panic(fmt.Sprintf("failed to create smoke tests, %v", err))
}
ignoreImports := `
var _ aws.Config
var _ awserr.Error
var _ request.Request
`
return a.importsGoCode() + ignoreImports + w.String()
}
var smokeTestTmpl = template.Must(template.New(`smokeTestTmpl`).Parse(`
{{- range $i, $testCase := $.TestCases }}
{{- $op := index $.API.Operations $testCase.OpName }}
{{- if $op }}
func TestInteg_{{ printf "%02d" $i }}_{{ $op.ExportedName }}(t *testing.T) {
ctx, cancelFn := context.WithTimeout(context.Background(), 5 *time.Second)
defer cancelFn()
sess := integration.SessionWithDefaultRegion("{{ $.DefaultRegion }}")
svc := {{ $.API.PackageName }}.New(sess)
params := {{ $testCase.BuildInputShape $op.InputRef }}
_, err := svc.{{ $op.ExportedName }}WithContext(ctx, params, func(r *request.Request) {
r.Handlers.Validate.RemoveByName("core.ValidateParametersHandler")
})
{{- if $testCase.ExpectErr }}
if err == nil {
t.Fatalf("expect request to fail")
}
aerr, ok := err.(awserr.RequestFailure)
if !ok {
t.Fatalf("expect awserr, was %T", err)
}
if len(aerr.Code()) == 0 {
t.Errorf("expect non-empty error code")
}
if len(aerr.Message()) == 0 {
t.Errorf("expect non-empty error message")
}
if v := aerr.Code(); v == request.ErrCodeSerialization {
t.Errorf("expect API error code got serialization failure")
}
{{- else }}
if err != nil {
t.Errorf("expect no error, got %v", err)
}
{{- end }}
}
{{- end }}
{{- end }}
`))