github.com/relnod/pegomock@v2.0.1+incompatible/mockgen/mockgen.go (about)

     1  // Copyright 2015 Peter Goetz
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Based on the work done in
    16  // https://github.com/golang/mock/blob/d581abfc04272f381d7a05e4b80163ea4e2b9447/mockgen/mockgen.go
    17  
    18  // Package mockgen generates mock implementations of Go interfaces.
    19  package mockgen
    20  
    21  // TODO: This does not support recursive embedded interfaces.
    22  // TODO: This does not support embedding package-local interfaces in a separate file.
    23  
import (
	"bytes"
	"fmt"
	"go/format"
	"go/token"
	"path"
	"sort"
	"strconv"
	"strings"
	"unicode"

	"github.com/petergtz/pegomock/model"
)
    36  
    37  const mockFrameworkImportPath = "github.com/petergtz/pegomock"
    38  
    39  func GenerateOutput(ast *model.Package, source, packageOut, selfPackage string) ([]byte, map[string]string) {
    40  	g := generator{typesSet: make(map[string]string)}
    41  	g.generateCode(source, ast, packageOut, selfPackage)
    42  	return g.formattedOutput(), g.typesSet
    43  }
    44  
    45  type generator struct {
    46  	buf        bytes.Buffer
    47  	packageMap map[string]string // map from import path to package name
    48  	typesSet   map[string]string
    49  }
    50  
    51  func (g *generator) generateCode(source string, pkg *model.Package, pkgName, selfPackage string) {
    52  	g.p("// Code generated by pegomock. DO NOT EDIT.")
    53  	g.p("// Source: %v", source)
    54  	g.emptyLine()
    55  
    56  	importPaths := pkg.Imports()
    57  	importPaths[mockFrameworkImportPath] = true
    58  	packageMap, nonVendorPackageMap := generateUniquePackageNamesFor(importPaths)
    59  	g.packageMap = packageMap
    60  
    61  	g.p("package %v", pkgName)
    62  	g.emptyLine()
    63  	g.p("import (")
    64  	g.p("\"reflect\"")
    65  	g.p("\"time\"")
    66  	for packagePath, packageName := range nonVendorPackageMap {
    67  		if packagePath != selfPackage && packagePath != "time" && packagePath != "reflect" {
    68  			g.p("%v %q", packageName, packagePath)
    69  		}
    70  	}
    71  	for _, packagePath := range pkg.DotImports {
    72  		g.p(". %q", packagePath)
    73  	}
    74  	g.p(")")
    75  
    76  	for _, iface := range pkg.Interfaces {
    77  		g.generateMockFor(iface, selfPackage)
    78  	}
    79  }
    80  
    81  func generateUniquePackageNamesFor(importPaths map[string]bool) (packageMap, nonVendorPackageMap map[string]string) {
    82  	packageMap = make(map[string]string, len(importPaths))
    83  	nonVendorPackageMap = make(map[string]string, len(importPaths))
    84  	packageNamesAlreadyUsed := make(map[string]bool, len(importPaths))
    85  	for importPath := range importPaths {
    86  		sanitizedPackagePathBaseName := sanitize(path.Base(importPath))
    87  
    88  		// Local names for an imported package can usually be the basename of the import path.
    89  		// A couple of situations don't permit that, such as duplicate local names
    90  		// (e.g. importing "html/template" and "text/template"), or where the basename is
    91  		// a keyword (e.g. "foo/case").
    92  		// try base0, base1, ...
    93  		packageName := sanitizedPackagePathBaseName
    94  		for i := 0; packageNamesAlreadyUsed[packageName] || token.Lookup(packageName).IsKeyword(); i++ {
    95  			packageName = sanitizedPackagePathBaseName + strconv.Itoa(i)
    96  		}
    97  
    98  		packageMap[importPath] = packageName
    99  		packageNamesAlreadyUsed[packageName] = true
   100  
   101  		nonVendorPackageMap[vendorCleaned(importPath)] = packageName
   102  	}
   103  	return
   104  }
   105  
   106  func vendorCleaned(importPath string) string {
   107  	if split := strings.Split(importPath, "/vendor/"); len(split) > 1 {
   108  		return split[1]
   109  	}
   110  	return importPath
   111  }
   112  
   113  // sanitize cleans up a string to make a suitable package name.
   114  // pkgName in reflect mode is the base name of the import path,
   115  // which might have characters that are illegal to have in package names.
   116  func sanitize(s string) string {
   117  	t := ""
   118  	for _, r := range s {
   119  		if t == "" {
   120  			if unicode.IsLetter(r) || r == '_' {
   121  				t += string(r)
   122  				continue
   123  			}
   124  		} else {
   125  			if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' {
   126  				t += string(r)
   127  				continue
   128  			}
   129  		}
   130  		t += "_"
   131  	}
   132  	if t == "_" {
   133  		t = "x"
   134  	}
   135  	return t
   136  }
   137  
   138  func (g *generator) generateMockFor(iface *model.Interface, selfPackage string) {
   139  	mockTypeName := "Mock" + iface.Name
   140  	g.generateMockType(mockTypeName)
   141  	for _, method := range iface.Methods {
   142  		g.generateMockMethod(mockTypeName, method, selfPackage)
   143  		g.emptyLine()
   144  
   145  		addTypesFromMethodParamsTo(g.typesSet, method.In, g.packageMap)
   146  		addTypesFromMethodParamsTo(g.typesSet, method.Out, g.packageMap)
   147  	}
   148  	g.generateMockVerifyMethods(iface.Name)
   149  	g.generateVerifierType(iface.Name)
   150  	for _, method := range iface.Methods {
   151  		ongoingVerificationTypeName := fmt.Sprintf("%v_%v_OngoingVerification", iface.Name, method.Name)
   152  		args, argNames, argTypes, _ := argDataFor(method, g.packageMap, selfPackage)
   153  		g.generateVerifierMethod(iface.Name, method, selfPackage, ongoingVerificationTypeName, args, argNames)
   154  		g.generateOngoingVerificationType(iface.Name, ongoingVerificationTypeName)
   155  		g.generateOngoingVerificationGetCapturedArguments(ongoingVerificationTypeName, argNames, argTypes)
   156  		g.generateOngoingVerificationGetAllCapturedArguments(ongoingVerificationTypeName, argTypes, method.Variadic != nil)
   157  	}
   158  }
   159  
   160  func (g *generator) generateMockType(mockTypeName string) {
   161  	g.
   162  		emptyLine().
   163  		p("type %v struct {", mockTypeName).
   164  		p("	fail func(message string, callerSkip ...int)").
   165  		p("}").
   166  		emptyLine().
   167  		p("func New%v(options ...pegomock.Option) *%v {", mockTypeName, mockTypeName).
   168  		p("	mock := &%v{}", mockTypeName).
   169  		p("	for _, option := range options {").
   170  		p("		option.Apply(mock)").
   171  		p("	}").
   172  		p("	return mock").
   173  		p("}").
   174  		emptyLine().
   175  		p("func (mock *%v) SetFailHandler(fh pegomock.FailHandler) { mock.fail = fh }", mockTypeName).
   176  		p("func (mock *%v) FailHandler() pegomock.FailHandler      { return mock.fail }", mockTypeName).
   177  		emptyLine()
   178  }
   179  
   180  // If non-empty, pkgOverride is the package in which unqualified types reside.
   181  func (g *generator) generateMockMethod(mockType string, method *model.Method, pkgOverride string) *generator {
   182  	args, argNames, _, returnTypes := argDataFor(method, g.packageMap, pkgOverride)
   183  	g.p("func (mock *%v) %v(%v) (%v) {", mockType, method.Name, join(args), join(returnTypes))
   184  	g.p("if mock == nil {").
   185  		p("	panic(\"mock must not be nil. Use myMock := New%v().\")", mockType).
   186  		p("}")
   187  	g.GenerateParamsDeclaration(argNames, method.Variadic != nil)
   188  	reflectReturnTypes := make([]string, len(returnTypes))
   189  	for i, returnType := range returnTypes {
   190  		reflectReturnTypes[i] = fmt.Sprintf("reflect.TypeOf((*%v)(nil)).Elem()", returnType)
   191  	}
   192  	resultAssignment := ""
   193  	if len(method.Out) > 0 {
   194  		resultAssignment = "result :="
   195  	}
   196  	g.p("%v pegomock.GetGenericMockFrom(mock).Invoke(\"%v\", params, []reflect.Type{%v})",
   197  		resultAssignment, method.Name, strings.Join(reflectReturnTypes, ", "))
   198  	if len(method.Out) > 0 {
   199  		// TODO: translate LastInvocation into a Matcher so it can be used as key for Stubbings
   200  		for i, returnType := range returnTypes {
   201  			g.p("var ret%v %v", i, returnType)
   202  		}
   203  		g.p("if len(result) != 0 {")
   204  		returnValues := make([]string, len(returnTypes))
   205  		for i, returnType := range returnTypes {
   206  			g.p("if result[%v] != nil {", i)
   207  			g.p("ret%v  = result[%v].(%v)", i, i, returnType)
   208  			g.p("}")
   209  			returnValues[i] = fmt.Sprintf("ret%v", i)
   210  		}
   211  		g.p("}")
   212  		g.p("return %v", strings.Join(returnValues, ", "))
   213  	}
   214  	g.p("}")
   215  	return g
   216  }
   217  
   218  func (g *generator) generateVerifierType(interfaceName string) *generator {
   219  	return g.
   220  		p("type Verifier%v struct {", interfaceName).
   221  		p("	mock *Mock%v", interfaceName).
   222  		p("	invocationCountMatcher pegomock.Matcher").
   223  		p("	inOrderContext *pegomock.InOrderContext").
   224  		p("	timeout time.Duration").
   225  		p("}").
   226  		emptyLine()
   227  }
   228  
   229  func (g *generator) generateMockVerifyMethods(interfaceName string) {
   230  	g.
   231  		p("func (mock *Mock%v) VerifyWasCalledOnce() *Verifier%v {", interfaceName, interfaceName).
   232  		p("	return &Verifier%v{", interfaceName).
   233  		p("		mock: mock,").
   234  		p("		invocationCountMatcher: pegomock.Times(1),").
   235  		p("	}").
   236  		p("}").
   237  		emptyLine().
   238  		p("func (mock *Mock%v) VerifyWasCalled(invocationCountMatcher pegomock.Matcher) *Verifier%v {", interfaceName, interfaceName).
   239  		p("	return &Verifier%v{", interfaceName).
   240  		p("		mock: mock,").
   241  		p("		invocationCountMatcher: invocationCountMatcher,").
   242  		p("	}").
   243  		p("}").
   244  		emptyLine().
   245  		p("func (mock *Mock%v) VerifyWasCalledInOrder(invocationCountMatcher pegomock.Matcher, inOrderContext *pegomock.InOrderContext) *Verifier%v {", interfaceName, interfaceName).
   246  		p("	return &Verifier%v{", interfaceName).
   247  		p("		mock: mock,").
   248  		p("		invocationCountMatcher: invocationCountMatcher,").
   249  		p("		inOrderContext: inOrderContext,").
   250  		p("	}").
   251  		p("}").
   252  		emptyLine().
   253  		p("func (mock *Mock%v) VerifyWasCalledEventually(invocationCountMatcher pegomock.Matcher, timeout time.Duration) *Verifier%v {", interfaceName, interfaceName).
   254  		p("	return &Verifier%v{", interfaceName).
   255  		p("		mock: mock,").
   256  		p("		invocationCountMatcher: invocationCountMatcher,").
   257  		p("		timeout: timeout,").
   258  		p("	}").
   259  		p("}").
   260  		emptyLine()
   261  }
   262  
   263  func (g *generator) generateVerifierMethod(interfaceName string, method *model.Method, pkgOverride string, returnTypeString string, args []string, argNames []string) *generator {
   264  	return g.
   265  		p("func (verifier *Verifier%v) %v(%v) *%v {", interfaceName, method.Name, join(args), returnTypeString).
   266  		GenerateParamsDeclaration(argNames, method.Variadic != nil).
   267  		p("methodInvocations := pegomock.GetGenericMockFrom(verifier.mock).Verify(verifier.inOrderContext, verifier.invocationCountMatcher, \"%v\", params, verifier.timeout)", method.Name).
   268  		p("return &%v{mock: verifier.mock, methodInvocations: methodInvocations}", returnTypeString).
   269  		p("}")
   270  }
   271  
   272  func (g *generator) GenerateParamsDeclaration(argNames []string, isVariadic bool) *generator {
   273  	if isVariadic {
   274  		return g.
   275  			p("params := []pegomock.Param{%v}", strings.Join(argNames[0:len(argNames)-1], ", ")).
   276  			p("for _, param := range %v {", argNames[len(argNames)-1]).
   277  			p("params = append(params, param)").
   278  			p("}")
   279  	} else {
   280  		return g.p("params := []pegomock.Param{%v}", join(argNames))
   281  	}
   282  }
   283  
   284  func (g *generator) generateOngoingVerificationType(interfaceName string, ongoingVerificationStructName string) *generator {
   285  	return g.
   286  		p("type %v struct {", ongoingVerificationStructName).
   287  		p("mock *Mock%v", interfaceName).
   288  		p("	methodInvocations []pegomock.MethodInvocation").
   289  		p("}").
   290  		emptyLine()
   291  }
   292  
   293  func (g *generator) generateOngoingVerificationGetCapturedArguments(ongoingVerificationStructName string, argNames []string, argTypes []string) *generator {
   294  	g.p("func (c *%v) GetCapturedArguments() (%v) {", ongoingVerificationStructName, join(argTypes))
   295  	if len(argNames) > 0 {
   296  		indexedArgNames := make([]string, len(argNames))
   297  		for i, argName := range argNames {
   298  			indexedArgNames[i] = argName + "[len(" + argName + ")-1]"
   299  		}
   300  		g.p("%v := c.GetAllCapturedArguments()", join(argNames))
   301  		g.p("return %v", strings.Join(indexedArgNames, ", "))
   302  	}
   303  	g.p("}")
   304  	g.emptyLine()
   305  	return g
   306  }
   307  
   308  func (g *generator) generateOngoingVerificationGetAllCapturedArguments(ongoingVerificationStructName string, argTypes []string, isVariadic bool) *generator {
   309  	argsAsArray := make([]string, len(argTypes))
   310  	for i, argType := range argTypes {
   311  		argsAsArray[i] = fmt.Sprintf("_param%v []%v", i, argType)
   312  	}
   313  	g.p("func (c *%v) GetAllCapturedArguments() (%v) {", ongoingVerificationStructName, strings.Join(argsAsArray, ", "))
   314  	if len(argTypes) > 0 {
   315  		g.p("params := pegomock.GetGenericMockFrom(c.mock).GetInvocationParams(c.methodInvocations)")
   316  		g.p("if len(params) > 0 {")
   317  		for i, argType := range argTypes {
   318  			if isVariadic && i == len(argTypes)-1 {
   319  				variadicBasicType := strings.Replace(argType, "[]", "", 1)
   320  				g.
   321  					p("_param%v = make([]%v, len(params[%v]))", i, argType, i).
   322  					p("for u := range params[0] {"). // the number of invocations and hence len(params[x]) is equal for all x
   323  					p("_param%v[u] = make([]%v, len(params)-%v)", i, variadicBasicType, i).
   324  					p("for x := %v; x < len(params); x++ {", i).
   325  					p("if params[x][u] != nil {").
   326  					p("_param%v[u][x-%v] = params[x][u].(%v)", i, i, variadicBasicType).
   327  					p("}").
   328  					p("}").
   329  					p("}")
   330  				break
   331  			} else {
   332  				g.p("_param%v = make([]%v, len(params[%v]))", i, argType, i)
   333  				g.p("for u, param := range params[%v] {", i)
   334  				g.p("_param%v[u]=param.(%v)", i, argType)
   335  				g.p("}")
   336  			}
   337  		}
   338  		g.p("}")
   339  		g.p("return")
   340  	}
   341  	g.p("}")
   342  	g.emptyLine()
   343  	return g
   344  }
   345  
   346  func argDataFor(method *model.Method, packageMap map[string]string, pkgOverride string) (
   347  	args []string,
   348  	argNames []string,
   349  	argTypes []string,
   350  	returnTypes []string,
   351  ) {
   352  	args = make([]string, len(method.In))
   353  	argNames = make([]string, len(method.In))
   354  	argTypes = make([]string, len(args))
   355  	for i, arg := range method.In {
   356  		argName := arg.Name
   357  		if argName == "" {
   358  			argName = fmt.Sprintf("_param%d", i)
   359  		}
   360  		argType := arg.Type.String(packageMap, pkgOverride)
   361  		args[i] = argName + " " + argType
   362  		argNames[i] = argName
   363  		argTypes[i] = argType
   364  	}
   365  	if method.Variadic != nil {
   366  		argName := method.Variadic.Name
   367  		if argName == "" {
   368  			argName = fmt.Sprintf("_param%d", len(method.In))
   369  		}
   370  		argType := method.Variadic.Type.String(packageMap, pkgOverride)
   371  		args = append(args, argName+" ..."+argType)
   372  		argNames = append(argNames, argName)
   373  		argTypes = append(argTypes, "[]"+argType)
   374  	}
   375  	returnTypes = make([]string, len(method.Out))
   376  	for i, ret := range method.Out {
   377  		returnTypes[i] = ret.Type.String(packageMap, pkgOverride)
   378  	}
   379  	return
   380  }
   381  
   382  func addTypesFromMethodParamsTo(typesSet map[string]string, params []*model.Parameter, packageMap map[string]string) {
   383  	for _, param := range params {
   384  		switch typedType := param.Type.(type) {
   385  		case *model.NamedType, *model.PointerType, *model.ArrayType, *model.MapType, *model.ChanType:
   386  			if _, exists := typesSet[underscoreNameFor(typedType, packageMap)]; !exists {
   387  				typesSet[underscoreNameFor(typedType, packageMap)] = generateMatcherSourceCode(typedType, packageMap)
   388  			}
   389  		case *model.FuncType:
   390  			// matcher generation for funcs not supported yet
   391  			// TODO implement
   392  		case model.PredeclaredType:
   393  			// skip. These come as part of pegomock.
   394  		default:
   395  			panic("Should not get here")
   396  		}
   397  	}
   398  }
   399  
   400  func generateMatcherSourceCode(t model.Type, packageMap map[string]string) string {
   401  	return fmt.Sprintf(`// Code generated by pegomock. DO NOT EDIT.
   402  package matchers
   403  
   404  import (
   405  	"reflect"
   406  	"github.com/petergtz/pegomock"
   407  	%v
   408  )
   409  
   410  func Any%v() %v {
   411  	pegomock.RegisterMatcher(pegomock.NewAnyMatcher(reflect.TypeOf((*(%v))(nil)).Elem()))
   412  	var nullValue %v
   413  	return nullValue
   414  }
   415  
   416  func Eq%v(value %v) %v {
   417  	pegomock.RegisterMatcher(&pegomock.EqMatcher{Value: value})
   418  	var nullValue %v
   419  	return nullValue
   420  }
   421  `,
   422  		optionalPackageOf(t, packageMap),
   423  		camelcaseNameFor(t, packageMap),
   424  		t.String(packageMap, ""),
   425  		t.String(packageMap, ""),
   426  		t.String(packageMap, ""),
   427  
   428  		camelcaseNameFor(t, packageMap),
   429  		t.String(packageMap, ""),
   430  		t.String(packageMap, ""),
   431  		t.String(packageMap, ""),
   432  	)
   433  }
   434  
   435  func optionalPackageOf(t model.Type, packageMap map[string]string) string {
   436  	switch typedType := t.(type) {
   437  	case model.PredeclaredType:
   438  		return ""
   439  	case *model.NamedType:
   440  		return fmt.Sprintf("%v \"%v\"", packageMap[typedType.Package], vendorCleaned(typedType.Package))
   441  	case *model.PointerType:
   442  		return optionalPackageOf(typedType.Type, packageMap)
   443  	case *model.ArrayType:
   444  		return optionalPackageOf(typedType.Type, packageMap)
   445  	case *model.MapType:
   446  		return optionalPackageOf(typedType.Key, packageMap) + "\n" + optionalPackageOf(typedType.Value, packageMap)
   447  	case *model.ChanType:
   448  		return optionalPackageOf(typedType.Type, packageMap)
   449  		// TODO:
   450  	// case *model.FuncType:
   451  	default:
   452  		panic(fmt.Sprintf("TODO implement optionalPackageOf for: %v\nis type of %T\n", typedType, typedType))
   453  	}
   454  }
   455  
   456  func spaceSeparatedNameFor(t model.Type, packageMap map[string]string) string {
   457  	switch typedType := t.(type) {
   458  	case model.PredeclaredType:
   459  		tt := typedType.String(packageMap, "")
   460  		if tt == "interface{}" {
   461  			// if a predeclared type is interface
   462  			// return a string type without curly brackets
   463  			return "interface"
   464  		}
   465  		return tt
   466  	case *model.NamedType:
   467  		return strings.Replace((typedType.String(packageMap, "")), ".", " ", -1)
   468  	case *model.PointerType:
   469  		return "ptr to " + spaceSeparatedNameFor(typedType.Type, packageMap)
   470  	case *model.ArrayType:
   471  		if typedType.Len == -1 {
   472  			return "slice of " + spaceSeparatedNameFor(typedType.Type, packageMap)
   473  		} else {
   474  			return "array of " + spaceSeparatedNameFor(typedType.Type, packageMap)
   475  		}
   476  	case *model.MapType:
   477  		return "map of " + spaceSeparatedNameFor(typedType.Key, packageMap) + " to " + spaceSeparatedNameFor(typedType.Value, packageMap)
   478  	case *model.ChanType:
   479  		return "chan of " + spaceSeparatedNameFor(typedType.Type, packageMap)
   480  	// TODO:
   481  	// case *model.FuncType:
   482  	default:
   483  		return fmt.Sprintf("TODO implement matcher for: %v\nis type of %T\n", typedType, typedType)
   484  	}
   485  }
   486  
   487  func camelcaseNameFor(t model.Type, packageMap map[string]string) string {
   488  	return strings.Replace(strings.Title(strings.Replace(spaceSeparatedNameFor(t, packageMap), "_", " ", -1)), " ", "", -1)
   489  }
   490  
   491  func underscoreNameFor(t model.Type, packageMap map[string]string) string {
   492  	return strings.ToLower(strings.Replace(spaceSeparatedNameFor(t, packageMap), " ", "_", -1))
   493  }
   494  
   495  func (g *generator) p(format string, args ...interface{}) *generator {
   496  	fmt.Fprintf(&g.buf, format+"\n", args...)
   497  	return g
   498  }
   499  
   500  func (g *generator) emptyLine() *generator { return g.p("") }
   501  
   502  func (g *generator) formattedOutput() []byte {
   503  	src, err := format.Source(g.buf.Bytes())
   504  	if err != nil {
   505  		panic(fmt.Errorf("Failed to format generated source code: %s\n%s", err, g.buf.String()))
   506  	}
   507  	return src
   508  }
   509  
   510  func join(s []string) string { return strings.Join(s, ", ") }