func genTestMain()

in go/tools/builders/generate_test_main.go [212:392]


// genTestMain implements the test-main generator command. It scans the -src
// files for tests, benchmarks, fuzz targets, runnable examples and TestMain,
// then renders testMainTpl with the collected Cases to -output (or to stdout
// when -output is empty).
//
// Flags, parsed after params files are expanded by expandParamsFiles and in
// addition to those registered by envFlags:
//
//	-output      file to write; defaults to stdout
//	-cover_mode  copied into Cases.CoverMode
//	-pkgname     copied into Cases.Pkgname
//	-import      repeatable NAME=PATH; a package the generated main imports
//	-src         repeatable PKG=FILE; a source file to scan, where PKG is the
//	             import name its functions are referenced through
//
// Functions are recognized by name prefix and shape: TestXxx(*x.T),
// BenchmarkXxx(*x.B) and FuzzXxx(*x.F), each taking exactly one parameter and
// returning nothing. Methods are ignored. A file whose package clause ends in
// "_test" is attributed to PKG+"_test".
//
// Any import whose NAME contributed nothing is emitted as a blank ("_")
// import. Imports are ordered by name, then path, so the generated source is
// reproducible.
func genTestMain(args []string) error {
	args, err := expandParamsFiles(args)
	if err != nil {
		return err
	}

	// Prepare our flags.
	imports := multiFlag{}
	sources := multiFlag{}
	flags := flag.NewFlagSet("GoTestGenTest", flag.ExitOnError)
	goenv := envFlags(flags)
	out := flags.String("output", "", "output file to write. Defaults to stdout.")
	coverMode := flags.String("cover_mode", "", "the coverage mode to use")
	pkgname := flags.String("pkgname", "", "package name of test")
	flags.Var(&imports, "import", "Packages to import")
	flags.Var(&sources, "src", "Sources to process for tests")
	if err := flags.Parse(args); err != nil {
		return err
	}
	if err := goenv.checkFlags(); err != nil {
		return err
	}

	// splitArg splits a NAME=VALUE flag value; kind is used only in the error.
	splitArg := func(kind, arg string) (string, string, error) {
		parts := strings.Split(arg, "=")
		if len(parts) != 2 {
			return "", "", fmt.Errorf("Invalid %s %q specified", kind, arg)
		}
		return parts[0], parts[1], nil
	}

	// Process import args.
	importMap := map[string]*Import{}
	for _, imp := range imports {
		name, importPath, err := splitArg("import", imp)
		if err != nil {
			return err
		}
		importMap[name] = &Import{Name: name, Path: importPath}
	}

	// Process source args.
	sourceList := make([]string, 0, len(sources))
	sourceMap := map[string]string{}
	for _, s := range sources {
		pkg, file, err := splitArg("source", s)
		if err != nil {
			return err
		}
		sourceList = append(sourceList, file)
		sourceMap[file] = pkg
	}

	// Filter our input file list.
	filteredSrcs, err := filterAndSplitFiles(sourceList)
	if err != nil {
		return err
	}

	// hasTestingParam reports whether fn has the shape func(x *<q>.typeName)
	// with no results. Like cmd/go, it requires exactly one parameter name
	// (a grouped list such as "a, b *testing.T" declares two) and treats an
	// explicit empty result list "()" the same as no result list.
	hasTestingParam := func(fn *ast.FuncDecl, typeName string) bool {
		params := fn.Type.Params.List
		if len(params) != 1 || len(params[0].Names) > 1 {
			return false
		}
		if fn.Type.Results != nil && len(fn.Type.Results.List) > 0 {
			return false
		}
		star, ok := params[0].Type.(*ast.StarExpr)
		if !ok {
			return false
		}
		// The qualifier is deliberately not checked: the testing package may
		// be imported under an alias, so any *<something>.typeName matches.
		sel, ok := star.X.(*ast.SelectorExpr)
		return ok && sel.Sel.Name == typeName
	}

	cases := Cases{
		CoverMode: *coverMode,
		Pkgname:   *pkgname,
	}
	// pkgs records every package name that contributes at least one symbol to
	// the generated main; the remaining imports become blank imports below.
	pkgs := map[string]bool{}
	testFileSet := token.NewFileSet()
	for _, f := range filteredSrcs.goSrcs {
		parsed, err := parser.ParseFile(testFileSet, f.filename, nil, parser.ParseComments)
		if err != nil {
			return fmt.Errorf("ParseFile(%q): %w", f.filename, err)
		}
		pkg := sourceMap[f.filename]
		if strings.HasSuffix(parsed.Name.String(), "_test") {
			pkg += "_test"
		}

		// Only examples that declare expected output (possibly empty) are run.
		for _, e := range doc.Examples(parsed) {
			if e.Output == "" && !e.EmptyOutput {
				continue
			}
			cases.Examples = append(cases.Examples, Example{
				Name:      "Example" + e.Name,
				Package:   pkg,
				Output:    e.Output,
				Unordered: e.Unordered,
			})
			pkgs[pkg] = true
		}

		for _, d := range parsed.Decls {
			fn, ok := d.(*ast.FuncDecl)
			if !ok || fn.Recv != nil {
				continue
			}
			name := fn.Name.Name
			switch {
			case name == "TestMain":
				// TestMain is not, itself, a test; its signature is not checked.
				cases.TestMain = fmt.Sprintf("%s.%s", pkg, name)
			case strings.HasPrefix(name, "Test") && hasTestingParam(fn, "T"):
				cases.Tests = append(cases.Tests, TestCase{Package: pkg, Name: name})
			case strings.HasPrefix(name, "Benchmark") && hasTestingParam(fn, "B"):
				cases.Benchmarks = append(cases.Benchmarks, TestCase{Package: pkg, Name: name})
			case strings.HasPrefix(name, "Fuzz") && hasTestingParam(fn, "F"):
				cases.FuzzTargets = append(cases.FuzzTargets, TestCase{Package: pkg, Name: name})
			default:
				continue
			}
			pkgs[pkg] = true
		}
	}

	for name, imp := range importMap {
		// Set the names for all unused imports to "_".
		if !pkgs[name] {
			imp.Name = "_"
		}
		cases.Imports = append(cases.Imports, imp)
	}
	// Several imports may now share the name "_". sort.Slice is not stable
	// and map iteration order is random, so break ties on the path to keep
	// the generated source identical across runs.
	sort.Slice(cases.Imports, func(i, j int) bool {
		a, b := cases.Imports[i], cases.Imports[j]
		if a.Name != b.Name {
			return a.Name < b.Name
		}
		return a.Path < b.Path
	})

	tpl := template.Must(template.New("source").Parse(testMainTpl))

	// The output file is created only after every input parsed successfully.
	outFile := os.Stdout
	if *out != "" {
		f, err := os.Create(*out)
		if err != nil {
			return fmt.Errorf("os.Create(%q): %w", *out, err)
		}
		// Safety net for early returns only; the success path closes the file
		// explicitly below so that an error from Close is reported.
		defer f.Close()
		outFile = f
	}
	if err := tpl.Execute(outFile, &cases); err != nil {
		return fmt.Errorf("template.Execute(%v): %w", cases, err)
	}
	if *out != "" {
		if err := outFile.Close(); err != nil {
			return fmt.Errorf("closing %q: %w", *out, err)
		}
	}
	return nil
}