github.com/petergtz/pegomock@v2.9.1-0.20230424204322-eb0e044013df+incompatible/dsl.go (about)

     1  // Copyright 2015 Peter Goetz
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package pegomock
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"reflect"
    21  	"sort"
    22  	"sync"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/onsi/gomega/format"
    27  	"github.com/petergtz/pegomock/internal/verify"
    28  )
    29  
    30  var GlobalFailHandler FailHandler
    31  
    32  func RegisterMockFailHandler(handler FailHandler) {
    33  	GlobalFailHandler = handler
    34  }
    35  func RegisterMockTestingT(t *testing.T) {
    36  	RegisterMockFailHandler(BuildTestingTFailHandler(t))
    37  }
    38  
    39  var (
    40  	lastInvocation      *invocation
    41  	lastInvocationMutex sync.Mutex
    42  )
    43  
    44  var globalArgMatchers Matchers
    45  
    46  func RegisterMatcher(matcher ArgumentMatcher) {
    47  	globalArgMatchers.append(matcher)
    48  }
    49  
    50  type invocation struct {
    51  	genericMock *GenericMock
    52  	MethodName  string
    53  	Params      []Param
    54  	ReturnTypes []reflect.Type
    55  }
    56  
    57  type GenericMock struct {
    58  	sync.Mutex
    59  	mockedMethods map[string]*mockedMethod
    60  	fail          FailHandler
    61  }
    62  
    63  func (genericMock *GenericMock) Invoke(methodName string, params []Param, returnTypes []reflect.Type) ReturnValues {
    64  	lastInvocationMutex.Lock()
    65  	lastInvocation = &invocation{
    66  		genericMock: genericMock,
    67  		MethodName:  methodName,
    68  		Params:      params,
    69  		ReturnTypes: returnTypes,
    70  	}
    71  	lastInvocationMutex.Unlock()
    72  	return genericMock.getOrCreateMockedMethod(methodName).Invoke(params)
    73  }
    74  
    75  func (genericMock *GenericMock) stub(methodName string, paramMatchers []ArgumentMatcher, returnValues ReturnValues) {
    76  	genericMock.stubWithCallback(methodName, paramMatchers, func([]Param) ReturnValues { return returnValues })
    77  }
    78  
    79  func (genericMock *GenericMock) stubWithCallback(methodName string, paramMatchers []ArgumentMatcher, callback func([]Param) ReturnValues) {
    80  	genericMock.getOrCreateMockedMethod(methodName).stub(paramMatchers, callback)
    81  }
    82  
    83  func (genericMock *GenericMock) getOrCreateMockedMethod(methodName string) *mockedMethod {
    84  	genericMock.Lock()
    85  	defer genericMock.Unlock()
    86  	if _, ok := genericMock.mockedMethods[methodName]; !ok {
    87  		genericMock.mockedMethods[methodName] = &mockedMethod{name: methodName}
    88  	}
    89  	return genericMock.mockedMethods[methodName]
    90  }
    91  
    92  func (genericMock *GenericMock) reset(methodName string, paramMatchers []ArgumentMatcher) {
    93  	genericMock.getOrCreateMockedMethod(methodName).reset(paramMatchers)
    94  }
    95  
    96  func (genericMock *GenericMock) Verify(
    97  	inOrderContext *InOrderContext,
    98  	invocationCountMatcher InvocationCountMatcher,
    99  	methodName string,
   100  	params []Param,
   101  	options ...interface{},
   102  ) []MethodInvocation {
   103  	var timeout time.Duration
   104  	if len(options) == 1 {
   105  		timeout = options[0].(time.Duration)
   106  	}
   107  	if genericMock.fail == nil && GlobalFailHandler == nil {
   108  		panic("No FailHandler set. Please use either RegisterMockFailHandler or RegisterMockTestingT or TODO to set a fail handler.")
   109  	}
   110  	fail := GlobalFailHandler
   111  	if genericMock.fail != nil {
   112  		fail = genericMock.fail
   113  	}
   114  	defer func() { globalArgMatchers = nil }() // We don't want a panic somewhere during verification screw our global argMatchers
   115  
   116  	if len(globalArgMatchers) != 0 {
   117  		verifyArgMatcherUse(globalArgMatchers, params)
   118  	}
   119  	startTime := time.Now()
   120  	// timeoutLoop:
   121  	for {
   122  		genericMock.Lock()
   123  		methodInvocations := genericMock.methodInvocations(methodName, params, globalArgMatchers)
   124  		genericMock.Unlock()
   125  		if inOrderContext != nil {
   126  			for _, methodInvocation := range methodInvocations {
   127  				if methodInvocation.orderingInvocationNumber <= inOrderContext.invocationCounter {
   128  					// TODO: should introduce the following, in case we decide support "inorder" and "eventually"
   129  					// if time.Since(startTime) < timeout {
   130  					// 	continue timeoutLoop
   131  					// }
   132  					fail(fmt.Sprintf("Expected function call %v(%v) before function call %v(%v)",
   133  						methodName, formatParams(params), inOrderContext.lastInvokedMethodName, formatParams(inOrderContext.lastInvokedMethodParams)))
   134  				}
   135  				inOrderContext.invocationCounter = methodInvocation.orderingInvocationNumber
   136  				inOrderContext.lastInvokedMethodName = methodName
   137  				inOrderContext.lastInvokedMethodParams = params
   138  			}
   139  		}
   140  		if !invocationCountMatcher.Matches(len(methodInvocations)) {
   141  			if time.Since(startTime) < timeout {
   142  				time.Sleep(10 * time.Millisecond)
   143  				continue
   144  			}
   145  			var paramsOrMatchers interface{} = formatParams(params)
   146  			if len(globalArgMatchers) != 0 {
   147  				paramsOrMatchers = formatMatchers(globalArgMatchers)
   148  			}
   149  			timeoutInfo := ""
   150  			if timeout > 0 {
   151  				timeoutInfo = fmt.Sprintf(" after timeout of %v", timeout)
   152  			}
   153  			fail(fmt.Sprintf(
   154  				"Mock invocation count for %v(%v) does not match expectation%v.\n\n\t%v\n\n\t%v",
   155  				methodName, paramsOrMatchers, timeoutInfo, invocationCountMatcher.FailureMessage(), formatInteractions(genericMock.allInteractions())))
   156  		}
   157  		return methodInvocations
   158  	}
   159  }
   160  
   161  // TODO this doesn't need to be a method, can be a free function
   162  func (genericMock *GenericMock) GetInvocationParams(methodInvocations []MethodInvocation) [][]Param {
   163  	if len(methodInvocations) == 0 {
   164  		return nil
   165  	}
   166  	result := make([][]Param, len(methodInvocations[len(methodInvocations)-1].params))
   167  	for i, invocation := range methodInvocations {
   168  		for u, param := range invocation.params {
   169  			if result[u] == nil {
   170  				result[u] = make([]Param, len(methodInvocations))
   171  			}
   172  			result[u][i] = param
   173  		}
   174  	}
   175  	return result
   176  }
   177  
   178  func (genericMock *GenericMock) methodInvocations(methodName string, params []Param, matchers []ArgumentMatcher) []MethodInvocation {
   179  	var invocations []MethodInvocation
   180  	if method, exists := genericMock.mockedMethods[methodName]; exists {
   181  		method.Lock()
   182  		for _, invocation := range method.invocations {
   183  			if len(matchers) != 0 {
   184  				if Matchers(matchers).Matches(invocation.params) {
   185  					invocations = append(invocations, invocation)
   186  				}
   187  			} else {
   188  				if reflect.DeepEqual(params, invocation.params) ||
   189  					(len(params) == 0 && len(invocation.params) == 0) {
   190  					invocations = append(invocations, invocation)
   191  				}
   192  			}
   193  		}
   194  		method.Unlock()
   195  	}
   196  	return invocations
   197  }
   198  
   199  func formatInteractions(interactions map[string][]MethodInvocation) string {
   200  	if len(interactions) == 0 {
   201  		return "There were no other interactions with this mock"
   202  	}
   203  	result := "Actual interactions with this mock were:\n"
   204  	for _, methodName := range sortedMethodNames(interactions) {
   205  		result += formatInvocations(methodName, interactions[methodName])
   206  	}
   207  	return result
   208  }
   209  
   210  func formatInvocations(methodName string, invocations []MethodInvocation) (result string) {
   211  	for _, invocation := range invocations {
   212  		result += "\t" + methodName + "(" + formatParams(invocation.params) + ")\n"
   213  	}
   214  	return
   215  }
   216  
   217  func formatParams(params []Param) (result string) {
   218  	for i, param := range params {
   219  		if i > 0 {
   220  			result += ", "
   221  		}
   222  		result += fmt.Sprintf("%#v", param)
   223  	}
   224  	return
   225  }
   226  
   227  func formatMatchers(matchers []ArgumentMatcher) (result string) {
   228  	for i, matcher := range matchers {
   229  		if i > 0 {
   230  			result += ", "
   231  		}
   232  		result += fmt.Sprintf("%v", matcher)
   233  	}
   234  	return
   235  }
   236  
   237  func sortedMethodNames(interactions map[string][]MethodInvocation) []string {
   238  	methodNames := make([]string, len(interactions))
   239  	i := 0
   240  	for key := range interactions {
   241  		methodNames[i] = key
   242  		i++
   243  	}
   244  	sort.Strings(methodNames)
   245  	return methodNames
   246  }
   247  
   248  func (genericMock *GenericMock) allInteractions() map[string][]MethodInvocation {
   249  	interactions := make(map[string][]MethodInvocation)
   250  	for methodName := range genericMock.mockedMethods {
   251  		for _, invocation := range genericMock.mockedMethods[methodName].invocations {
   252  			interactions[methodName] = append(interactions[methodName], invocation)
   253  		}
   254  	}
   255  	return interactions
   256  }
   257  
   258  type mockedMethod struct {
   259  	sync.Mutex
   260  	name        string
   261  	invocations []MethodInvocation
   262  	stubbings   Stubbings
   263  }
   264  
   265  func (method *mockedMethod) Invoke(params []Param) ReturnValues {
   266  	method.Lock()
   267  	method.invocations = append(method.invocations, MethodInvocation{params, globalInvocationCounter.nextNumber()})
   268  	method.Unlock()
   269  	stubbing := method.stubbings.find(params)
   270  	if stubbing == nil {
   271  		return ReturnValues{}
   272  	}
   273  	return stubbing.Invoke(params)
   274  }
   275  
   276  func (method *mockedMethod) stub(paramMatchers Matchers, callback func([]Param) ReturnValues) {
   277  	stubbing := method.stubbings.findByMatchers(paramMatchers)
   278  	if stubbing == nil {
   279  		stubbing = &Stubbing{paramMatchers: paramMatchers}
   280  		method.stubbings = append(method.stubbings, stubbing)
   281  	}
   282  	stubbing.callbackSequence = append(stubbing.callbackSequence, callback)
   283  }
   284  
   285  func (method *mockedMethod) removeLastInvocation() {
   286  	method.invocations = method.invocations[:len(method.invocations)-1]
   287  }
   288  
   289  func (method *mockedMethod) reset(paramMatchers Matchers) {
   290  	method.stubbings.removeByMatchers(paramMatchers)
   291  }
   292  
   293  type Counter struct {
   294  	count int
   295  	sync.Mutex
   296  }
   297  
   298  func (counter *Counter) nextNumber() (nextNumber int) {
   299  	counter.Lock()
   300  	defer counter.Unlock()
   301  
   302  	nextNumber = counter.count
   303  	counter.count++
   304  	return
   305  }
   306  
   307  var globalInvocationCounter = Counter{count: 1}
   308  
// MethodInvocation is one recorded call of a mocked method.
// NOTE: mockedMethod.Invoke builds it with an unkeyed composite literal, so
// the field order below is load-bearing.
type MethodInvocation struct {
	params                   []Param // arguments the mock was called with
	orderingInvocationNumber int     // drawn from globalInvocationCounter; compared against InOrderContext.invocationCounter
}
   313  
   314  type Stubbings []*Stubbing
   315  
   316  func (stubbings Stubbings) find(params []Param) *Stubbing {
   317  	for i := len(stubbings) - 1; i >= 0; i-- {
   318  		if stubbings[i].paramMatchers.Matches(params) {
   319  			return stubbings[i]
   320  		}
   321  	}
   322  	return nil
   323  }
   324  
   325  func (stubbings Stubbings) findByMatchers(paramMatchers Matchers) *Stubbing {
   326  	for _, stubbing := range stubbings {
   327  		if matchersEqual(stubbing.paramMatchers, paramMatchers) {
   328  			return stubbing
   329  		}
   330  	}
   331  	return nil
   332  }
   333  
   334  func (stubbings *Stubbings) removeByMatchers(paramMatchers Matchers) {
   335  	for i, stubbing := range *stubbings {
   336  		if matchersEqual(stubbing.paramMatchers, paramMatchers) {
   337  			*stubbings = append((*stubbings)[:i], (*stubbings)[i+1:]...)
   338  		}
   339  	}
   340  }
   341  
   342  func matchersEqual(a, b Matchers) bool {
   343  	if len(a) != len(b) {
   344  		return false
   345  	}
   346  	for i := range a {
   347  		if !reflect.DeepEqual(a[i], b[i]) {
   348  			return false
   349  		}
   350  	}
   351  	return true
   352  }
   353  
   354  type Stubbing struct {
   355  	paramMatchers    Matchers
   356  	callbackSequence []func([]Param) ReturnValues
   357  	sequencePointer  int
   358  }
   359  
   360  func (stubbing *Stubbing) Invoke(params []Param) ReturnValues {
   361  	defer func() {
   362  		if stubbing.sequencePointer < len(stubbing.callbackSequence)-1 {
   363  			stubbing.sequencePointer++
   364  		}
   365  	}()
   366  	return stubbing.callbackSequence[stubbing.sequencePointer](params)
   367  }
   368  
   369  type Matchers []ArgumentMatcher
   370  
   371  func (matchers Matchers) Matches(params []Param) bool {
   372  	if len(matchers) != len(params) { // Technically, this is not an error. Variadic arguments can cause this
   373  		return false
   374  	}
   375  
   376  	for i := range params {
   377  		if !matchers[i].Matches(params[i]) {
   378  			return false
   379  		}
   380  	}
   381  	return true
   382  }
   383  
   384  func (matchers *Matchers) append(matcher ArgumentMatcher) {
   385  	*matchers = append(*matchers, matcher)
   386  }
   387  
   388  type ongoingStubbing struct {
   389  	genericMock   *GenericMock
   390  	MethodName    string
   391  	ParamMatchers []ArgumentMatcher
   392  	returnTypes   []reflect.Type
   393  }
   394  
   395  func When(invocation ...interface{}) *ongoingStubbing {
   396  	callIfIsFunc(invocation)
   397  	verify.Argument(lastInvocation != nil,
   398  		"When() requires an argument which has to be 'a method call on a mock'.")
   399  	defer func() {
   400  		lastInvocationMutex.Lock()
   401  		lastInvocation = nil
   402  		lastInvocationMutex.Unlock()
   403  
   404  		globalArgMatchers = nil
   405  	}()
   406  	lastInvocation.genericMock.mockedMethods[lastInvocation.MethodName].removeLastInvocation()
   407  
   408  	paramMatchers := paramMatchersFromArgMatchersOrParams(globalArgMatchers, lastInvocation.Params)
   409  	lastInvocation.genericMock.reset(lastInvocation.MethodName, paramMatchers)
   410  	return &ongoingStubbing{
   411  		genericMock:   lastInvocation.genericMock,
   412  		MethodName:    lastInvocation.MethodName,
   413  		ParamMatchers: paramMatchers,
   414  		returnTypes:   lastInvocation.ReturnTypes,
   415  	}
   416  }
   417  
   418  func callIfIsFunc(invocation []interface{}) {
   419  	if len(invocation) == 1 {
   420  		actualType := actualTypeOf(invocation[0])
   421  		if actualType != nil && actualType.Kind() == reflect.Func && !reflect.ValueOf(invocation[0]).IsNil() {
   422  			if !(actualType.NumIn() == 0 && actualType.NumOut() == 0) {
   423  				panic("When using 'When' with function that does not return a value, " +
   424  					"it expects a function with no arguments and no return value.")
   425  			}
   426  			reflect.ValueOf(invocation[0]).Call([]reflect.Value{})
   427  		}
   428  	}
   429  }
   430  
   431  // Deals with nils without panicking
   432  func actualTypeOf(iface interface{}) reflect.Type {
   433  	defer func() { recover() }()
   434  	return reflect.TypeOf(iface)
   435  }
   436  
   437  func paramMatchersFromArgMatchersOrParams(argMatchers []ArgumentMatcher, params []Param) []ArgumentMatcher {
   438  	if len(argMatchers) != 0 {
   439  		verifyArgMatcherUse(argMatchers, params)
   440  		return argMatchers
   441  	}
   442  	return transformParamsIntoEqMatchers(params)
   443  }
   444  
   445  func verifyArgMatcherUse(argMatchers []ArgumentMatcher, params []Param) {
   446  	verify.Argument(len(argMatchers) == len(params),
   447  		"Invalid use of matchers!\n\n %v matchers expected, %v recorded.\n\n"+
   448  			"This error may occur if matchers are combined with raw values:\n"+
   449  			"    //incorrect:\n"+
   450  			"    someFunc(AnyInt(), \"raw String\")\n"+
   451  			"When using matchers, all arguments have to be provided by matchers.\n"+
   452  			"For example:\n"+
   453  			"    //correct:\n"+
   454  			"    someFunc(AnyInt(), EqString(\"String by matcher\"))",
   455  		len(params), len(argMatchers),
   456  	)
   457  }
   458  
   459  func transformParamsIntoEqMatchers(params []Param) []ArgumentMatcher {
   460  	paramMatchers := make([]ArgumentMatcher, len(params))
   461  	for i, param := range params {
   462  		paramMatchers[i] = &EqMatcher{Value: param}
   463  	}
   464  	return paramMatchers
   465  }
   466  
   467  var (
   468  	genericMocksMutex sync.Mutex
   469  	genericMocks      = make(map[Mock]*GenericMock)
   470  )
   471  
   472  func GetGenericMockFrom(mock Mock) *GenericMock {
   473  	genericMocksMutex.Lock()
   474  	defer genericMocksMutex.Unlock()
   475  	if genericMocks[mock] == nil {
   476  		genericMocks[mock] = &GenericMock{
   477  			mockedMethods: make(map[string]*mockedMethod),
   478  			fail:          mock.FailHandler(),
   479  		}
   480  	}
   481  	return genericMocks[mock]
   482  }
   483  
   484  func (stubbing *ongoingStubbing) ThenReturn(values ...ReturnValue) *ongoingStubbing {
   485  	checkAssignabilityOf(values, stubbing.returnTypes)
   486  	stubbing.genericMock.stub(stubbing.MethodName, stubbing.ParamMatchers, values)
   487  	return stubbing
   488  }
   489  
   490  func checkAssignabilityOf(stubbedReturnValues []ReturnValue, expectedReturnTypes []reflect.Type) {
   491  	verify.Argument(len(stubbedReturnValues) == len(expectedReturnTypes),
   492  		"Different number of return values")
   493  	for i := range stubbedReturnValues {
   494  		if stubbedReturnValues[i] == nil {
   495  			switch expectedReturnTypes[i].Kind() {
   496  			case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint,
   497  				reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Float32,
   498  				reflect.Float64, reflect.Complex64, reflect.Complex128, reflect.Array, reflect.String,
   499  				reflect.Struct:
   500  				panic("Return value 'nil' not assignable to return type " + expectedReturnTypes[i].Kind().String())
   501  			}
   502  		} else {
   503  			verify.Argument(reflect.TypeOf(stubbedReturnValues[i]).AssignableTo(expectedReturnTypes[i]),
   504  				"Return value of type %T not assignable to return type %v", stubbedReturnValues[i], expectedReturnTypes[i])
   505  		}
   506  	}
   507  }
   508  
   509  func (stubbing *ongoingStubbing) ThenPanic(v interface{}) *ongoingStubbing {
   510  	stubbing.genericMock.stubWithCallback(
   511  		stubbing.MethodName,
   512  		stubbing.ParamMatchers,
   513  		func([]Param) ReturnValues { panic(v) })
   514  	return stubbing
   515  }
   516  
   517  func (stubbing *ongoingStubbing) Then(callback func([]Param) ReturnValues) *ongoingStubbing {
   518  	stubbing.genericMock.stubWithCallback(
   519  		stubbing.MethodName,
   520  		stubbing.ParamMatchers,
   521  		callback)
   522  	return stubbing
   523  }
   524  
   525  type InOrderContext struct {
   526  	invocationCounter       int
   527  	lastInvokedMethodName   string
   528  	lastInvokedMethodParams []Param
   529  }
   530  
   531  // ArgumentMatcher can be used to match arguments.
   532  type ArgumentMatcher interface {
   533  	Matches(param Param) bool
   534  	fmt.Stringer
   535  }
   536  
   537  // InvocationCountMatcher can be used to match invocation counts. It is guaranteed that
   538  // FailureMessage will always be called after Matches so an implementation can save state.
   539  type InvocationCountMatcher interface {
   540  	Matches(param Param) bool
   541  	FailureMessage() string
   542  }
   543  
   544  // Matcher can be used to match arguments as well as invocation counts.
   545  // Note that support for overlapping embedded interfaces was added in Go 1.14, which is why
   546  // ArgumentMatcher and InvocationCountMatcher are not embedded here.
   547  type Matcher interface {
   548  	Matches(param Param) bool
   549  	FailureMessage() string
   550  	fmt.Stringer
   551  }
   552  
   553  func DumpInvocationsFor(mock Mock) {
   554  	fmt.Print(SDumpInvocationsFor(mock))
   555  }
   556  
   557  func SDumpInvocationsFor(mock Mock) string {
   558  	result := &bytes.Buffer{}
   559  	for _, mockedMethod := range GetGenericMockFrom(mock).mockedMethods {
   560  		for _, invocation := range mockedMethod.invocations {
   561  			fmt.Fprintf(result, "Method invocation: %v (\n", mockedMethod.name)
   562  			for _, param := range invocation.params {
   563  				fmt.Fprint(result, format.Object(param, 1), ",\n")
   564  			}
   565  			fmt.Fprintln(result, ")")
   566  		}
   567  	}
   568  	return result.String()
   569  }
   570  
   571  // InterceptMockFailures runs a given callback and returns an array of
   572  // failure messages generated by any Pegomock verifications within the callback.
   573  //
   574  // This is accomplished by temporarily replacing the *global* fail handler
   575  // with a fail handler that simply annotates failures.  The original fail handler
   576  // is reset when InterceptMockFailures returns.
   577  func InterceptMockFailures(f func()) []string {
   578  	originalHandler := GlobalFailHandler
   579  	failures := []string{}
   580  	RegisterMockFailHandler(func(message string, callerSkip ...int) {
   581  		failures = append(failures, message)
   582  	})
   583  	f()
   584  	RegisterMockFailHandler(originalHandler)
   585  	return failures
   586  }