codegen/mockgen.go (130 lines of code) (raw):

// Copyright (c) 2023 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package codegen import ( "bytes" "go/token" "os/exec" "path" "sort" "strconv" "strings" "github.com/golang/mock/mockgen/model" "github.com/pkg/errors" ) const ( mockgenPkg = "github.com/golang/mock/mockgen" ) // MockgenBin is a struct abstracts the mockgen binary built from mockgen package in vendor type MockgenBin struct { pkgHelper *PackageHelper tmpl *Template } // NewMockgenBin builds the mockgen binary from vendor directory func NewMockgenBin(h *PackageHelper, t *Template) (*MockgenBin, error) { return &MockgenBin{ pkgHelper: h, tmpl: t, }, nil } // GenMock generates mocks for given module instance, pkg is the package name of the generated mocks, // and intf is the interface name to generate mock for func (m MockgenBin) GenMock(importPath, pkg, intf string) ([]byte, error) { cmd := exec.Command("mockgen", "-package", pkg, importPath, intf) var stdout, stderr bytes.Buffer cmd.Stdout = &stdout cmd.Stderr = &stderr if err := cmd.Run(); err != nil { return nil, errors.Wrapf( err, "error running command %q: %s", strings.Join(cmd.Args, " "), stderr.String(), ) } return stdout.Bytes(), nil } // byMethodName implements sort.Interface for []*modelMethod based on the Name field type byMethodName []*model.Method func (b byMethodName) Len() int { return len(b) } func (b byMethodName) Swap(i, j int) { b[i], b[j] = b[j], b[i] } func (b byMethodName) Less(i, j int) bool { return b[i].Name < b[j].Name } // AugmentMockWithFixture generates mocks with fixture for the interface in the given package func (m MockgenBin) AugmentMockWithFixture(pkg *model.Package, f *Fixture, intf string) ([]byte, []byte, error) { methodsMap := make(map[string]*model.Method, len(pkg.Interfaces[0].Methods)) validationMap := make(map[string]interface{}, len(pkg.Interfaces[0].Methods)) for _, m := range pkg.Interfaces[0].Methods { methodsMap[m.Name] = m validationMap[m.Name] = struct{}{} } if err := f.Validate(validationMap); err != nil { return nil, nil, errors.Wrap(err, "invalid fixture config") } exposedMethods := make([]*model.Method, 0, len(f.Scenarios)) for name := range f.Scenarios { exposedMethods = append(exposedMethods, methodsMap[name]) } // sort methods in given fixture config for predictable fixture type generation sort.Sort(byMethodName(exposedMethods)) pkgPathToAlias := uniqueAlias(pkg.Imports()) methods := make([]*reflectMethod, 0, len(exposedMethods)) for _, m := range exposedMethods { numIn := len(m.In) in := make(map[string]string, numIn) inString := make([]string, 0, numIn) for i, param := range m.In { arg := "arg" + strconv.Itoa(i) in[arg] = param.Type.String(pkgPathToAlias, "") inString = append(inString, arg) } numOut := len(m.Out) out := make(map[string]string, numOut) outString := make([]string, 0, numOut) for i, param := range m.Out { ret := "ret" + strconv.Itoa(i) out[ret] = param.Type.String(pkgPathToAlias, "") outString = append(outString, ret) } method := &reflectMethod{ Name: m.Name, In: in, Out: out, InString: strings.Join(inString, " ,"), OutString: strings.Join(outString, " ,"), } if m.Variadic != nil { method.Variadic = "arg" + strconv.Itoa(len(m.In)) method.VariadicType = m.Variadic.Type.String(pkgPathToAlias, "") } methods = append(methods, method) } data := map[string]interface{}{ "Imports": pkgPathToAlias, "Methods": methods, "Fixture": f, "ClientInterface": intf, } types, err := m.tmpl.ExecTemplate("fixture_types.tmpl", data, m.pkgHelper) if err != nil { return nil, nil, err } mock, err := m.tmpl.ExecTemplate("augmented_mock.tmpl", data, m.pkgHelper) if err != nil { return nil, nil, err } return types, mock, nil } type reflectMethod struct { Name string In, Out map[string]string Variadic string VariadicType string InString, OutString string } // uniqueAlias returns a map of import path to alias where the aliases are unique func uniqueAlias(importPaths map[string]bool) map[string]string { pkgPathToAlias := make(map[string]string, len(importPaths)) usedAliases := make(map[string]bool, len(importPaths)) for pkgPath := range importPaths { base := CamelCase(path.Base(pkgPath)) pkgAlias := base i := 0 for usedAliases[pkgAlias] || token.Lookup(pkgAlias).IsKeyword() { pkgAlias = base + strconv.Itoa(i) i++ } pkgPathToAlias[pkgPath] = pkgAlias usedAliases[pkgAlias] = true } return pkgPathToAlias }