in go/tools/builders/generate_test_main.go [212:392]
// genTestMain writes the Go source of a test main program, rendered from
// testMainTpl, for the tests, benchmarks, fuzz targets, examples and TestMain
// found in the given sources.
//
// args may reference params files (expanded by expandParamsFiles). Recognized
// flags, in addition to those registered by envFlags:
//
//	-output      file to write; the result goes to stdout when empty
//	-cover_mode  coverage mode, passed through to the template
//	-pkgname     package name of the test, passed through to the template
//	-import      repeatable "name=importpath" pair
//	-src         repeatable "name=filename" pair; name is the import name
//	             (as given to -import) of the package the file belongs to
//
// Malformed -import or -src values are reported as errors.
func genTestMain(args []string) error {
	// Prepare our flags
	args, err := expandParamsFiles(args)
	if err != nil {
		return err
	}
	imports := multiFlag{}
	sources := multiFlag{}
	flags := flag.NewFlagSet("GoTestGenTest", flag.ExitOnError)
	goenv := envFlags(flags)
	out := flags.String("output", "", "output file to write. Defaults to stdout.")
	coverMode := flags.String("cover_mode", "", "the coverage mode to use")
	pkgname := flags.String("pkgname", "", "package name of test")
	flags.Var(&imports, "import", "Packages to import")
	flags.Var(&sources, "src", "Sources to process for tests")
	if err := flags.Parse(args); err != nil {
		return err
	}
	if err := goenv.checkFlags(); err != nil {
		return err
	}

	// Process import args: each is "name=importpath", indexed by name.
	importMap := map[string]*Import{}
	for _, imp := range imports {
		parts := strings.Split(imp, "=")
		if len(parts) != 2 {
			return fmt.Errorf("Invalid import %q specified", imp)
		}
		i := &Import{Name: parts[0], Path: parts[1]}
		importMap[i.Name] = i
	}

	// Process source args: each is "name=filename". sourceMap maps each file
	// to the import name of the package it belongs to.
	sourceList := make([]string, 0, len(sources))
	sourceMap := map[string]string{}
	for _, s := range sources {
		parts := strings.Split(s, "=")
		if len(parts) != 2 {
			return fmt.Errorf("Invalid source %q specified", s)
		}
		sourceList = append(sourceList, parts[1])
		sourceMap[parts[1]] = parts[0]
	}

	// filter our input file list
	filteredSrcs, err := filterAndSplitFiles(sourceList)
	if err != nil {
		return err
	}
	goSrcs := filteredSrcs.goSrcs

	cases := Cases{
		CoverMode: *coverMode,
		Pkgname:   *pkgname,
	}
	testFileSet := token.NewFileSet()
	// pkgs records the import names the generated main actually references.
	pkgs := map[string]bool{}
	for _, f := range goSrcs {
		parse, err := parser.ParseFile(testFileSet, f.filename, nil, parser.ParseComments)
		if err != nil {
			return fmt.Errorf("ParseFile(%q): %w", f.filename, err)
		}
		// Files declaring a package whose name ends in "_test" are referenced
		// through the source's import name with a "_test" suffix.
		pkg := sourceMap[f.filename]
		if strings.HasSuffix(parse.Name.String(), "_test") {
			pkg += "_test"
		}
		// Only examples that declare an output (possibly empty) are collected.
		for _, e := range doc.Examples(parse) {
			if e.Output == "" && !e.EmptyOutput {
				continue
			}
			cases.Examples = append(cases.Examples, Example{
				Name:      "Example" + e.Name,
				Package:   pkg,
				Output:    e.Output,
				Unordered: e.Unordered,
			})
			pkgs[pkg] = true
		}
		for _, d := range parse.Decls {
			fn, ok := d.(*ast.FuncDecl)
			if !ok || fn.Recv != nil {
				continue
			}
			name := fn.Name.Name
			if name == "TestMain" {
				// TestMain is not, itself, a test
				pkgs[pkg] = true
				cases.TestMain = fmt.Sprintf("%s.%s", pkg, name)
				continue
			}
			// To be considered a test, benchmark or fuzz target the function
			// must take exactly one parameter, return nothing, and that
			// parameter must have type *<something>.T/B/F. The qualifier is
			// deliberately not checked: the testing package may have been
			// imported under an alias.
			if len(fn.Type.Params.List) != 1 || fn.Type.Results != nil {
				continue
			}
			starExpr, ok := fn.Type.Params.List[0].Type.(*ast.StarExpr)
			if !ok {
				continue
			}
			selExpr, ok := starExpr.X.(*ast.SelectorExpr)
			if !ok {
				continue
			}
			tc := TestCase{Package: pkg, Name: name}
			// The three prefixes are mutually exclusive, so at most one case
			// can match a given name.
			switch typ := selExpr.Sel.Name; {
			case strings.HasPrefix(name, "Test") && typ == "T":
				cases.Tests = append(cases.Tests, tc)
			case strings.HasPrefix(name, "Benchmark") && typ == "B":
				cases.Benchmarks = append(cases.Benchmarks, tc)
			case strings.HasPrefix(name, "Fuzz") && typ == "F":
				cases.FuzzTargets = append(cases.FuzzTargets, tc)
			default:
				continue
			}
			pkgs[pkg] = true
		}
	}

	// Imports the generated main does not reference are kept as blank ("_")
	// imports, since a named but unused import would not compile.
	for name, imp := range importMap {
		if !pkgs[name] {
			imp.Name = "_"
		}
		cases.Imports = append(cases.Imports, imp)
	}
	// Several imports may now share the name "_" and the slice was built in
	// map-iteration order, so break ties on Path; otherwise sort.Slice (which
	// is not stable) makes the generated file differ from run to run.
	sort.Slice(cases.Imports, func(i, j int) bool {
		a, b := cases.Imports[i], cases.Imports[j]
		if a.Name != b.Name {
			return a.Name < b.Name
		}
		return a.Path < b.Path
	})

	tpl := template.Must(template.New("source").Parse(testMainTpl))

	// Open the output only after every input has been processed, so input
	// errors do not leave an empty output file behind.
	outFile := os.Stdout
	if *out != "" {
		outFile, err = os.Create(*out)
		if err != nil {
			return fmt.Errorf("os.Create(%q): %w", *out, err)
		}
		// Cleanup for the error path only; on success the file is closed
		// explicitly below so that an error reported by Close is not lost.
		defer outFile.Close()
	}
	if err := tpl.Execute(outFile, &cases); err != nil {
		return fmt.Errorf("template.Execute(%v): %w", cases, err)
	}
	if *out != "" {
		if err := outFile.Close(); err != nil {
			return fmt.Errorf("Close(%q): %w", *out, err)
		}
	}
	return nil
}