github.com/rancher/moq@v0.0.0-20200712062324-13d1f37d2d77/pkg/moq/moq.go (about)

     1  package moq
     2  
import (
	"bytes"
	"errors"
	"fmt"
	"go/build"
	"go/types"
	"io"
	"os"
	"path"
	"path/filepath"
	"strconv"
	"strings"
	"text/template"
	"unicode"
	"unicode/utf8"

	"golang.org/x/tools/go/packages"
)
    19  
// Mocker can generate mock structs.
//
// A Mocker is bound to one source package (loaded by New) and renders mocks
// for interfaces declared in it. The import bookkeeping maps below are filled
// while rendering and are not reset by Mock, so they accumulate across calls.
type Mocker struct {
	// srcPkg is the source package, loaded with names, types, type info and syntax.
	srcPkg  *packages.Package
	// tmpl is the parsed "moq" template.
	tmpl    *template.Template
	// pkgName is the package clause of the generated file.
	pkgName string
	// pkgPath is the import path the mock is written into ("" if unresolved);
	// types from this path are rendered without a qualifier.
	pkgPath string
	// fmter formats the rendered source before it is written out.
	fmter   func(src []byte) ([]byte, error)

	// importToAlias maps an import path to the alias used for it in the
	// generated file; "" means the path is imported without an alias.
	importToAlias map[string]string
	// importAliases records identifiers already taken by imports (aliases and
	// unaliased package names). It is used to pick non-clashing aliases and to
	// rename parameters that would shadow an import.
	importAliases map[string]bool
	// importLines is the set of import specs required by qualified types that
	// appear in mocked method signatures.
	importLines   map[string]bool
}
    32  
// Config specifies details about how interfaces should be mocked.
// SrcDir is the only field which needs to be specified.
type Config struct {
	// SrcDir is the directory of the package declaring the interfaces to mock.
	SrcDir    string
	// PkgName is the package name of the generated file. When empty, the
	// source package's name is used.
	PkgName   string
	// Formatter selects the output formatter: "goimports" selects goimports,
	// and any other value (including "") selects gofmt.
	Formatter string
}
    40  
    41  // New makes a new Mocker for the specified package directory.
    42  func New(conf Config) (*Mocker, error) {
    43  	srcPkg, err := pkgInfoFromPath(conf.SrcDir, packages.NeedName|packages.NeedTypes|packages.NeedTypesInfo|packages.NeedSyntax)
    44  	if err != nil {
    45  		return nil, fmt.Errorf("couldn't load source package: %s", err)
    46  	}
    47  
    48  	pkgName := conf.PkgName
    49  	if pkgName == "" {
    50  		pkgName = srcPkg.Name
    51  	}
    52  
    53  	pkgPath, err := findPkgPath(conf.PkgName, srcPkg)
    54  	if err != nil {
    55  		return nil, fmt.Errorf("couldn't load mock package: %s", err)
    56  	}
    57  
    58  	tmpl, err := template.New("moq").Funcs(templateFuncs).Parse(moqTemplate)
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  
    63  	fmter := gofmt
    64  	if conf.Formatter == "goimports" {
    65  		fmter = goimports
    66  	}
    67  
    68  	importAliases := make(map[string]bool)
    69  	importToAlias := make(map[string]string)
    70  
    71  	// Attempt to preserve original aliases for prettiness
    72  	for _, syntax := range srcPkg.Syntax {
    73  		for _, importSpec := range syntax.Imports {
    74  			if importSpec.Name != nil && importSpec.Path != nil {
    75  				importAliases[importSpec.Name.Name] = true
    76  				importToAlias[strings.Trim(importSpec.Path.Value, "\"")] = importSpec.Name.Name
    77  			}
    78  		}
    79  	}
    80  
    81  	return &Mocker{
    82  		tmpl:          tmpl,
    83  		srcPkg:        srcPkg,
    84  		pkgName:       pkgName,
    85  		pkgPath:       pkgPath,
    86  		fmter:         fmter,
    87  		importLines:   make(map[string]bool),
    88  		importAliases: importAliases,
    89  		importToAlias: importToAlias,
    90  	}, nil
    91  }
    92  
    93  func findPkgPath(pkgInputVal string, srcPkg *packages.Package) (string, error) {
    94  	if pkgInputVal == "" {
    95  		return srcPkg.PkgPath, nil
    96  	}
    97  	if pkgInDir(".", pkgInputVal) {
    98  		return ".", nil
    99  	}
   100  	if pkgInDir(srcPkg.PkgPath, pkgInputVal) {
   101  		return srcPkg.PkgPath, nil
   102  	}
   103  	subdirectoryPath := filepath.Join(srcPkg.PkgPath, pkgInputVal)
   104  	if pkgInDir(subdirectoryPath, pkgInputVal) {
   105  		return subdirectoryPath, nil
   106  	}
   107  	return "", nil
   108  }
   109  
   110  func pkgInDir(pkgName, dir string) bool {
   111  	currentPkg, err := pkgInfoFromPath(dir, packages.NeedName)
   112  	if err != nil {
   113  		return false
   114  	}
   115  	return currentPkg.Name == pkgName || currentPkg.Name+"_test" == pkgName
   116  }
   117  
   118  // Mock generates a mock for the specified interface name.
   119  func (m *Mocker) Mock(w io.Writer, names ...string) error {
   120  	if len(names) == 0 {
   121  		return errors.New("must specify one interface")
   122  	}
   123  
   124  	doc := doc{
   125  		PackageName: m.pkgName,
   126  	}
   127  
   128  	mocksMethods := false
   129  
   130  	paramCache := make(map[string][]*param)
   131  
   132  	tpkg := m.srcPkg.Types
   133  	for _, name := range names {
   134  		n, mockName := parseInterfaceName(name)
   135  		iface := tpkg.Scope().Lookup(n)
   136  		if iface == nil {
   137  			return fmt.Errorf("cannot find interface %s", n)
   138  		}
   139  		if !types.IsInterface(iface.Type()) {
   140  			return fmt.Errorf("%s (%s) not an interface", n, iface.Type().String())
   141  		}
   142  		iiface := iface.Type().Underlying().(*types.Interface).Complete()
   143  		obj := obj{
   144  			InterfaceName: n,
   145  			MockName:      mockName,
   146  		}
   147  		for i := 0; i < iiface.NumMethods(); i++ {
   148  			mocksMethods = true
   149  			meth := iiface.Method(i)
   150  			sig := meth.Type().(*types.Signature)
   151  			method := &method{
   152  				Name: meth.Name(),
   153  			}
   154  			obj.Methods = append(obj.Methods, method)
   155  			method.Params, method.Returns = m.extractArgs(sig)
   156  
   157  			for _, param := range method.Params {
   158  				paramCache[param.Name] = append(paramCache[param.Name], param)
   159  			}
   160  		}
   161  		doc.Objects = append(doc.Objects, obj)
   162  	}
   163  
   164  	if mocksMethods {
   165  		_, importLine := m.qualifierAndImportLine("sync", "sync")
   166  		doc.Imports = append(doc.Imports, importLine)
   167  	}
   168  
   169  	for pkgToImport := range m.importLines {
   170  		doc.Imports = append(doc.Imports, pkgToImport)
   171  	}
   172  
   173  	if tpkg.Name() != m.pkgName {
   174  		qualifier, importLine := m.qualifierAndImportLine(tpkg.Path(), tpkg.Name())
   175  		doc.SourcePackagePrefix = qualifier + "."
   176  		doc.Imports = append(doc.Imports, importLine)
   177  	}
   178  
   179  	for pkg := range m.importAliases {
   180  		if params, hasConflict := paramCache[pkg]; hasConflict {
   181  			for _, param := range params {
   182  				param.LocalName = fmt.Sprintf("%sMoqParam", param.LocalName)
   183  			}
   184  		}
   185  	}
   186  
   187  	var buf bytes.Buffer
   188  	err := m.tmpl.Execute(&buf, doc)
   189  	if err != nil {
   190  		return err
   191  	}
   192  	formatted, err := m.fmter(buf.Bytes())
   193  	if err != nil {
   194  		return err
   195  	}
   196  	if _, err := w.Write(formatted); err != nil {
   197  		return err
   198  	}
   199  	return nil
   200  }
   201  
   202  func (m *Mocker) allocAlias(path string, pkgName string) string {
   203  	suffix := 0
   204  	attemptedName := pkgName
   205  	for {
   206  		if _, taken := m.importAliases[attemptedName]; taken {
   207  			suffix++
   208  			attemptedName = fmt.Sprintf("%s%d", pkgName, suffix)
   209  			continue
   210  		}
   211  
   212  		m.importAliases[attemptedName] = true
   213  		m.importToAlias[path] = attemptedName
   214  
   215  		// Don't alias packages that don't require an alias
   216  		if attemptedName == pkgName {
   217  			m.importToAlias[path] = ""
   218  			return ""
   219  		}
   220  
   221  		return attemptedName
   222  	}
   223  }
   224  
   225  func (m *Mocker) getAlias(path string, pkgName string) string {
   226  	alias, aliasSet := m.importToAlias[path]
   227  	if !aliasSet {
   228  		alias = m.allocAlias(path, pkgName)
   229  	}
   230  	return alias
   231  }
   232  
   233  func (m *Mocker) qualifierAndImportLine(pkg, pkgName string) (string, string) {
   234  	pkg = stripVendorPath(pkg)
   235  	alias := m.getAlias(pkg, pkgName)
   236  	importLine := quoteImport(alias, pkg)
   237  	if alias == "" {
   238  		return pkgName, importLine
   239  	}
   240  	return alias, importLine
   241  }
   242  
   243  func quoteImport(alias, pkg string) string {
   244  	if alias == "" {
   245  		return fmt.Sprintf("\"%s\"", pkg)
   246  	}
   247  	return fmt.Sprintf("%s \"%s\"", alias, pkg)
   248  }
   249  
   250  func (m *Mocker) packageQualifier(pkg *types.Package) string {
   251  	if m.pkgPath != "" && m.pkgPath == pkg.Path() {
   252  		return ""
   253  	}
   254  	path := pkg.Path()
   255  	if pkg.Path() == "." {
   256  		wd, err := os.Getwd()
   257  		if err == nil {
   258  			path = stripGopath(wd)
   259  		}
   260  	}
   261  
   262  	qualifier, importLine := m.qualifierAndImportLine(path, pkg.Name())
   263  	m.importLines[importLine] = true
   264  	return qualifier
   265  }
   266  
   267  func (m *Mocker) extractArgs(sig *types.Signature) (params, results []*param) {
   268  	pp := sig.Params()
   269  	for i := 0; i < pp.Len(); i++ {
   270  		p := m.buildParam(pp.At(i), "in"+strconv.Itoa(i+1))
   271  		// check for final variadic argument
   272  		p.Variadic = sig.Variadic() && i == pp.Len()-1 && p.Type[0:2] == "[]"
   273  		params = append(params, p)
   274  	}
   275  
   276  	rr := sig.Results()
   277  	for i := 0; i < rr.Len(); i++ {
   278  		results = append(results, m.buildParam(rr.At(i), "out"+strconv.Itoa(i+1)))
   279  	}
   280  
   281  	return
   282  }
   283  
   284  func (m *Mocker) buildParam(v *types.Var, fallbackName string) *param {
   285  	name := v.Name()
   286  	if name == "" {
   287  		name = fallbackName
   288  	}
   289  	typ := types.TypeString(v.Type(), m.packageQualifier)
   290  	return &param{Name: name, LocalName: name, Type: typ}
   291  }
   292  
   293  func pkgInfoFromPath(srcDir string, mode packages.LoadMode) (*packages.Package, error) {
   294  	pkgs, err := packages.Load(&packages.Config{
   295  		Mode: mode,
   296  		Dir:  srcDir,
   297  	})
   298  	if err != nil {
   299  		return nil, err
   300  	}
   301  	if len(pkgs) == 0 {
   302  		return nil, errors.New("No packages found")
   303  	}
   304  	if len(pkgs) > 1 {
   305  		return nil, errors.New("More than one package was found")
   306  	}
   307  	return pkgs[0], nil
   308  }
   309  
   310  func parseInterfaceName(name string) (ifaceName, mockName string) {
   311  	parts := strings.SplitN(name, ":", 2)
   312  	ifaceName = parts[0]
   313  	mockName = ifaceName + "Mock"
   314  	if len(parts) == 2 {
   315  		mockName = parts[1]
   316  	}
   317  	return
   318  }
   319  
// doc is the data passed to the mock template for one generated file.
type doc struct {
	// PackageName is the package clause of the generated file.
	PackageName         string
	// SourcePackagePrefix is "<qualifier>." when the source package's name
	// differs from PackageName, and "" otherwise.
	SourcePackagePrefix string
	// Objects holds one entry per mocked interface, in argument order.
	Objects             []obj
	// Imports holds import specs such as `"sync"` or `alias "path"`.
	Imports             []string
}
   326  
// obj describes one mocked interface.
type obj struct {
	// InterfaceName is the interface's name in the source package.
	InterfaceName string
	// MockName is the name of the generated mock type.
	MockName      string
	// Methods lists the methods of the completed interface, which includes
	// methods from embedded interfaces.
	Methods       []*method
}
// method describes one method of a mocked interface.
type method struct {
	// Name is the method name.
	Name    string
	// Params are the method's parameters, in order.
	Params  []*param
	// Returns are the method's results, in order.
	Returns []*param
}
   337  
   338  func (m *method) Arglist() string {
   339  	params := make([]string, len(m.Params))
   340  	for i, p := range m.Params {
   341  		params[i] = p.String()
   342  	}
   343  	return strings.Join(params, ", ")
   344  }
   345  
   346  func (m *method) ArgCallList() string {
   347  	params := make([]string, len(m.Params))
   348  	for i, p := range m.Params {
   349  		params[i] = p.CallName()
   350  	}
   351  	return strings.Join(params, ", ")
   352  }
   353  
   354  func (m *method) ReturnArglist() string {
   355  	params := make([]string, len(m.Returns))
   356  	for i, p := range m.Returns {
   357  		params[i] = p.TypeString()
   358  	}
   359  	if len(m.Returns) > 1 {
   360  		return fmt.Sprintf("(%s)", strings.Join(params, ", "))
   361  	}
   362  	return strings.Join(params, ", ")
   363  }
   364  
// param describes one parameter or result of a mocked method.
type param struct {
	// Name is the declared name, or a generated in<N>/out<N> name when unnamed.
	Name      string
	// LocalName is the identifier used in generated code; it starts as Name
	// and gets a "MoqParam" suffix if Name clashes with an import identifier.
	LocalName string
	// Type is the type rendered by types.TypeString with the mock's qualifier.
	Type      string
	// Variadic marks the final parameter of a variadic signature; Type then
	// has the "[]T" form.
	Variadic  bool
}
   371  
   372  func (p param) String() string {
   373  	return fmt.Sprintf("%s %s", p.LocalName, p.TypeString())
   374  }
   375  
   376  func (p param) CallName() string {
   377  	if p.Variadic {
   378  		return p.LocalName + "..."
   379  	}
   380  	return p.LocalName
   381  }
   382  
   383  func (p param) TypeString() string {
   384  	if p.Variadic {
   385  		return "..." + p.Type[2:]
   386  	}
   387  	return p.Type
   388  }
   389  
   390  var templateFuncs = template.FuncMap{
   391  	"Exported": func(s string) string {
   392  		if s == "" {
   393  			return ""
   394  		}
   395  		for _, initialism := range golintInitialisms {
   396  			if strings.ToUpper(s) == initialism {
   397  				return initialism
   398  			}
   399  		}
   400  		return strings.ToUpper(s[0:1]) + s[1:]
   401  	},
   402  }
   403  
   404  // stripVendorPath strips the vendor dir prefix from a package path.
   405  // For example we might encounter an absolute path like
   406  // github.com/foo/bar/vendor/github.com/pkg/errors which is resolved
   407  // to github.com/pkg/errors.
   408  func stripVendorPath(p string) string {
   409  	parts := strings.Split(p, "/vendor/")
   410  	if len(parts) == 1 {
   411  		return p
   412  	}
   413  	return strings.TrimLeft(path.Join(parts[1:]...), "/")
   414  }
   415  
   416  // stripGopath takes the directory to a package and removes the
   417  // $GOPATH/src path to get the canonical package name.
   418  func stripGopath(p string) string {
   419  	for _, srcDir := range build.Default.SrcDirs() {
   420  		rel, err := filepath.Rel(srcDir, p)
   421  		if err != nil || strings.HasPrefix(rel, "..") {
   422  			continue
   423  		}
   424  		return filepath.ToSlash(rel)
   425  	}
   426  	return p
   427  }