tools/test-reader/reader/reader.go (386 lines of code) (raw):

package reader import ( "fmt" "go/ast" "go/parser" "go/token" "os" "path/filepath" "regexp" "strconv" "strings" "github.com/hashicorp/hcl/v2" "github.com/hashicorp/hcl/v2/gohcl" "github.com/hashicorp/hcl/v2/hclparse" "github.com/hashicorp/hcl/v2/hclsyntax" ) type Resource map[string]any // config of one resource in a test type Resources map[string]Resource // map of resource names to resource configs type Step map[string]Resources // map of resource types to resources of that type type Test struct { Name string Steps []Step } func (t *Test) String() string { return fmt.Sprintf("%s: %#v", t.Name, t.Steps) } // Return a slice of tests as well as a map of file or test names to errors encountered. func ReadAllTests(servicesDir string) ([]*Test, map[string]error) { dirs, err := os.ReadDir(servicesDir) if err != nil { return nil, map[string]error{servicesDir: err} } allTests := make([]*Test, 0) allErrs := make(map[string]error) for _, dir := range dirs { servicePath := filepath.Join(servicesDir, dir.Name()) files, err := os.ReadDir(servicePath) if err != nil { return nil, map[string]error{servicePath: err} } var testFileNames []string for _, file := range files { if strings.HasSuffix(file.Name(), "_test.go") { testFileNames = append(testFileNames, filepath.Join(servicePath, file.Name())) } } serviceTests, serviceErrs := ReadTestFiles(testFileNames) for fileName, err := range serviceErrs { allErrs[fileName] = err } allTests = append(allTests, serviceTests...) } if len(allErrs) > 0 { return allTests, allErrs } return allTests, nil } // Read all the test files in a service directory together to capture cross-file function usage. func ReadTestFiles(filenames []string) ([]*Test, map[string]error) { funcDecls := make(map[string]*ast.FuncDecl) // map of function names to function declarations varDecls := make(map[string]*ast.BasicLit) // map of variable names to value expressions errs := make(map[string]error) // map of file or test names to errors encountered parsing fset := token.NewFileSet() for _, filename := range filenames { f, err := parser.ParseFile(fset, filename, nil, 0) if err != nil { errs[filename] = err continue } for _, decl := range f.Decls { if funcDecl, ok := decl.(*ast.FuncDecl); ok { // This is a function declaration. funcDecls[funcDecl.Name.Name] = funcDecl } else if genDecl, ok := decl.(*ast.GenDecl); ok { // This is an import, constant, type, or variable declaration for _, spec := range genDecl.Specs { if valueSpec, ok := spec.(*ast.ValueSpec); ok { if len(valueSpec.Values) > 0 { if basicLit, ok := valueSpec.Values[0].(*ast.BasicLit); ok { varDecls[valueSpec.Names[0].Name] = basicLit } } } } } } } tests := make([]*Test, 0) for name, funcDecl := range funcDecls { if strings.HasPrefix(name, "TestAcc") { funcTests, err := readTestFunc(funcDecl, funcDecls, varDecls) if err != nil { errs[name] = err } tests = append(tests, funcTests...) } } if len(errs) > 0 { return tests, errs } return tests, nil } func readTestFunc(testFunc *ast.FuncDecl, funcDecls map[string]*ast.FuncDecl, varDecls map[string]*ast.BasicLit) ([]*Test, error) { // This is an exported test function. var tests []*Test var errs []error vars := make(map[string]*ast.CompositeLit, len(testFunc.Body.List)) // map of variable names to composite literal values in function body for _, stmt := range testFunc.Body.List { if exprStmt, ok := stmt.(*ast.ExprStmt); ok { if callExpr, ok := exprStmt.X.(*ast.CallExpr); ok { // This is a call expression. ident, isIdent := callExpr.Fun.(*ast.Ident) selExpr, isSelExpr := callExpr.Fun.(*ast.SelectorExpr) if isIdent && ident.Name == "VcrTest" || isSelExpr && selExpr.Sel.Name == "VcrTest" { test, err := readVcrTestCall(callExpr, funcDecls, varDecls) if err != nil { errs = append(errs, err) } test.Name = testFunc.Name.Name tests = append(tests, test) } } } else if assignStmt, ok := stmt.(*ast.AssignStmt); ok { if len(assignStmt.Lhs) == 1 && len(assignStmt.Rhs) == 1 { // For now, only allow single assignment variables for serial test maps. // e.g. testCases := map[string]func(t *testing.T) {... if ident, ok := assignStmt.Lhs[0].(*ast.Ident); ok { if rhsCompLit, ok := assignStmt.Rhs[0].(*ast.CompositeLit); ok { vars[ident.Name] = rhsCompLit } } } } else if rangeStmt, ok := stmt.(*ast.RangeStmt); ok { if ident, ok := rangeStmt.X.(*ast.Ident); ok { if varCompLit, ok := vars[ident.Name]; ok { serialTests, serialErrs := readSerialTestCompLit(varCompLit, funcDecls, varDecls) errs = append(errs, serialErrs...) tests = append(tests, serialTests...) } } } } if len(errs) > 0 { return tests, fmt.Errorf("errors reading test func %s: %v", testFunc.Name.Name, errs) } return tests, nil } // Reads a composite literal which is either a slice or a map of serialized test functions. func readSerialTestCompLit(varCompLit *ast.CompositeLit, funcDecls map[string]*ast.FuncDecl, varDecls map[string]*ast.BasicLit) ([]*Test, []error) { var tests []*Test var errs []error for _, elt := range varCompLit.Elts { if eltKeyValueExpr, ok := elt.(*ast.KeyValueExpr); ok { eltTests, err := readSerialTestEltKeyValueExpr(eltKeyValueExpr, funcDecls, varDecls) if err != nil { errs = append(errs, err) } tests = append(tests, eltTests...) } } return tests, errs } func readSerialTestEltKeyValueExpr(eltKeyValueExpr *ast.KeyValueExpr, funcDecls map[string]*ast.FuncDecl, varDecls map[string]*ast.BasicLit) ([]*Test, error) { if ident, ok := eltKeyValueExpr.Value.(*ast.Ident); ok { if testFunc, ok := funcDecls[ident.Name]; ok { return readTestFunc(testFunc, funcDecls, varDecls) } return nil, fmt.Errorf("failed to find function with name %s", ident.Name) } return nil, fmt.Errorf("element key value expression with key %+v had non-ident value %+v", eltKeyValueExpr.Key, eltKeyValueExpr.Value) } func readVcrTestCall(vcrTestCall *ast.CallExpr, funcDecls map[string]*ast.FuncDecl, varDecls map[string]*ast.BasicLit) (*Test, error) { for _, arg := range vcrTestCall.Args { if vcrTestArgCompLit, ok := arg.(*ast.CompositeLit); ok { if selExpr, ok := vcrTestArgCompLit.Type.(*ast.SelectorExpr); ok { if ident, ok := selExpr.X.(*ast.Ident); ok && ident.Name == "resource" && selExpr.Sel.Name == "TestCase" { return readTestCaseCompLit(vcrTestArgCompLit, funcDecls, varDecls) } } } } return nil, fmt.Errorf("failed to find TestCase in %v", vcrTestCall.Args) } func readTestCaseCompLit(testCaseCompLit *ast.CompositeLit, funcDecls map[string]*ast.FuncDecl, varDecls map[string]*ast.BasicLit) (*Test, error) { for _, elt := range testCaseCompLit.Elts { if keyValueExpr, ok := elt.(*ast.KeyValueExpr); ok { if ident, ok := keyValueExpr.Key.(*ast.Ident); ok && ident.Name == "Steps" { if stepsCompLit, ok := keyValueExpr.Value.(*ast.CompositeLit); ok { return readStepsCompLit(stepsCompLit, funcDecls, varDecls) } } } } return nil, fmt.Errorf("failed to find Steps in %v", testCaseCompLit.Elts) } func readStepsCompLit(stepsCompLit *ast.CompositeLit, funcDecls map[string]*ast.FuncDecl, varDecls map[string]*ast.BasicLit) (*Test, error) { test := &Test{} errs := make([]error, 0) for _, elt := range stepsCompLit.Elts { if eltCompLit, ok := elt.(*ast.CompositeLit); ok { for _, eltCompLitElt := range eltCompLit.Elts { if keyValueExpr, ok := eltCompLitElt.(*ast.KeyValueExpr); ok { if ident, ok := keyValueExpr.Key.(*ast.Ident); ok && ident.Name == "Config" { var configStr string var err error if configCallExpr, ok := keyValueExpr.Value.(*ast.CallExpr); ok { configStr, err = readConfigCallExpr(configCallExpr, funcDecls, varDecls) } else if ident, ok := keyValueExpr.Value.(*ast.Ident); ok { if configVar, ok := varDecls[ident.Name]; ok { configStr, err = strconv.Unquote(configVar.Value) } } if err != nil { errs = append(errs, err) } step, err := readConfigStr(configStr) if err != nil { errs = append(errs, err) } test.Steps = append(test.Steps, step) } } } } } if len(errs) > 0 { return test, fmt.Errorf("errors reading test steps: %v", errs) } return test, nil } // Read the call expression in the public test function that returns the config. func readConfigCallExpr(configCallExpr *ast.CallExpr, funcDecls map[string]*ast.FuncDecl, varDecls map[string]*ast.BasicLit) (string, error) { if ident, ok := configCallExpr.Fun.(*ast.Ident); ok { if configFunc, ok := funcDecls[ident.Name]; ok { return readConfigFunc(configFunc, funcDecls, varDecls) } return "", fmt.Errorf("failed to find function declaration %s", ident.Name) } return "", fmt.Errorf("failed to get ident for %v", configCallExpr.Fun) } func readConfigFunc(configFunc *ast.FuncDecl, funcDecls map[string]*ast.FuncDecl, varDecls map[string]*ast.BasicLit) (string, error) { for _, stmt := range configFunc.Body.List { if returnStmt, ok := stmt.(*ast.ReturnStmt); ok { if len(returnStmt.Results) > 0 { return readConfigFuncResult(returnStmt.Results[0], funcDecls, varDecls) } return "", fmt.Errorf("failed to find a config string in results %v", returnStmt.Results) } } return "", fmt.Errorf("failed to find a return statement in %v", configFunc.Body.List) } // Read the return result of a config func and return the config string. func readConfigFuncResult(result ast.Expr, funcDecls map[string]*ast.FuncDecl, varDecls map[string]*ast.BasicLit) (string, error) { if basicLit, ok := result.(*ast.BasicLit); ok && basicLit.Kind == token.STRING { return strconv.Unquote(basicLit.Value) } else if callExpr, ok := result.(*ast.CallExpr); ok { return readConfigFuncCallExpr(callExpr, funcDecls, varDecls) } else if binaryExpr, ok := result.(*ast.BinaryExpr); ok { xConfigStr, err := readConfigFuncResult(binaryExpr.X, funcDecls, varDecls) if err != nil { return "", err } yConfigStr, err := readConfigFuncResult(binaryExpr.Y, funcDecls, varDecls) if err != nil { return "", err } return xConfigStr + yConfigStr, nil } return "", fmt.Errorf("unknown config func result %v (%T)", result, result) } // Read the call expression in the config function that returns the config string. // The call expression can contain a nested call expression. // Return the config string. func readConfigFuncCallExpr(configFuncCallExpr *ast.CallExpr, funcDecls map[string]*ast.FuncDecl, varDecls map[string]*ast.BasicLit) (string, error) { if len(configFuncCallExpr.Args) > 0 { if basicLit, ok := configFuncCallExpr.Args[0].(*ast.BasicLit); ok && basicLit.Kind == token.STRING { return strconv.Unquote(basicLit.Value) } else if nestedCallExpr, ok := configFuncCallExpr.Args[0].(*ast.CallExpr); ok { return readConfigFuncCallExpr(nestedCallExpr, funcDecls, varDecls) } } // Config string not readable from args, attempt to read call expression as a helper function. return readConfigCallExpr(configFuncCallExpr, funcDecls, varDecls) } var subPattern = regexp.MustCompile("%({[^{}]*}|[vTtbcspqxXUeEfFgGdo])") // Read the config string and return a test step. func readConfigStr(configStr string) (Step, error) { // Remove fmt substitutions because they interfere with hcl parsing. // Replace with a value that can be parsed outside quotation marks. configStr = subPattern.ReplaceAllString(configStr, "true") parser := hclparse.NewParser() file, diagnostics := parser.ParseHCL([]byte(configStr), "config.hcl") if diagnostics.HasErrors() { return nil, fmt.Errorf("errors parsing hcl: %v", diagnostics.Errs()) } content, diagnostics := file.Body.Content(&hcl.BodySchema{ Blocks: []hcl.BlockHeaderSchema{ { Type: "resource", LabelNames: []string{"type", "name"}, }, { Type: "data", LabelNames: []string{"type", "name"}, }, { Type: "output", LabelNames: []string{"name"}, }, { Type: "provider", LabelNames: []string{"name"}, }, { Type: "locals", }, }, }) if diagnostics.HasErrors() { return nil, fmt.Errorf("errors getting hcl body content: %v", diagnostics.Errs()) } m := make(map[string]Resources) errs := make([]error, 0) for _, block := range content.Blocks { if len(block.Labels) != 2 { continue } if _, ok := m[block.Labels[0]]; !ok { // Create an empty map for this resource type. m[block.Labels[0]] = make(Resources) } // Use the resource name as a key. resourceConfig, err := readHCLBlockBody(block.Body, file.Bytes) if err != nil { errs = append(errs, err) } resourceConfig = flattenResource(resourceConfig, "") m[block.Labels[0]][block.Labels[1]] = resourceConfig } if len(errs) > 0 { return m, fmt.Errorf("errors reading hcl blocks: %v", errs) } return m, nil } func readHCLBlockBody(body hcl.Body, fileBytes []byte) (Resource, error) { var m Resource gohcl.DecodeBody(body, nil, &m) for k, v := range m { if attr, ok := v.(*hcl.Attribute); ok { m[k] = string(attr.Expr.Range().SliceBytes(fileBytes)) } } syntaxBody, ok := body.(*hclsyntax.Body) if !ok { return m, fmt.Errorf("couldn't get hclsyntax body from %v", body) } errs := make([]error, 0) for _, block := range syntaxBody.Blocks { blockConfig, err := readHCLBlockBody(block.Body, fileBytes) if err != nil { errs = append(errs, err) } if existing, ok := m[block.Type]; ok { // Merge the fields from the current block into the existing resource config. if existingResource, ok := existing.(Resource); ok { mergeResources(existingResource, blockConfig) } } else { m[block.Type] = blockConfig } } if len(errs) > 0 { return m, fmt.Errorf("errors reading hcl blocks: %v", errs) } return m, nil } // Perform a recursive one-way merge of b into a. func mergeResources(a, b Resource) { for k, bv := range b { if av, ok := a[k]; ok { if avr, ok := av.(Resource); ok { if bvr, ok := bv.(Resource); ok { mergeResources(avr, bvr) } } } else { a[k] = bv } } } func flattenResource(r Resource, parent string) Resource { flattened := make(Resource) if parent != "" { parent += "." } for fieldName, fieldValue := range r { key := parent + fieldName nestedObject, _ := fieldValue.(Resource) if nestedObject != nil { for childKey, childField := range flattenResource(nestedObject, key) { flattened[childKey] = childField } } else { flattened[key] = fieldValue } } return flattened }