github.com/joomcode/pegomock@v2.9.2-0.20220414140958-14f53b6b2a6c+incompatible/dsl.go (about) 1 // Copyright 2015 Peter Goetz 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package pegomock 16 17 import ( 18 "bytes" 19 "fmt" 20 "reflect" 21 "sort" 22 "sync" 23 "testing" 24 "time" 25 26 "github.com/onsi/gomega/format" 27 "github.com/petergtz/pegomock/internal/verify" 28 ) 29 30 var GlobalFailHandler FailHandler 31 32 func RegisterMockFailHandler(handler FailHandler) { 33 GlobalFailHandler = handler 34 } 35 func RegisterMockTestingT(t *testing.T) { 36 RegisterMockFailHandler(BuildTestingTFailHandler(t)) 37 } 38 39 var ( 40 lastInvocation *invocation 41 lastInvocationMutex sync.Mutex 42 ) 43 44 var globalArgMatchers Matchers 45 46 func RegisterMatcher(matcher ArgumentMatcher) { 47 globalArgMatchers.append(matcher) 48 } 49 50 type invocation struct { 51 genericMock *GenericMock 52 MethodName string 53 Params []Param 54 ReturnTypes []reflect.Type 55 } 56 57 type GenericMock struct { 58 sync.Mutex 59 mockedMethods map[string]*mockedMethod 60 fail FailHandler 61 } 62 63 func (genericMock *GenericMock) Invoke(methodName string, params []Param, returnTypes []reflect.Type) ReturnValues { 64 lastInvocationMutex.Lock() 65 lastInvocation = &invocation{ 66 genericMock: genericMock, 67 MethodName: methodName, 68 Params: params, 69 ReturnTypes: returnTypes, 70 } 71 lastInvocationMutex.Unlock() 72 return genericMock.getOrCreateMockedMethod(methodName).Invoke(params) 73 } 74 75 func (genericMock *GenericMock) stub(methodName string, paramMatchers []ArgumentMatcher, returnValues ReturnValues) { 76 genericMock.stubWithCallback(methodName, paramMatchers, func([]Param) ReturnValues { return returnValues }) 77 } 78 79 func (genericMock *GenericMock) stubWithCallback(methodName string, paramMatchers []ArgumentMatcher, callback func([]Param) ReturnValues) { 80 genericMock.getOrCreateMockedMethod(methodName).stub(paramMatchers, callback) 81 } 82 83 func (genericMock *GenericMock) getOrCreateMockedMethod(methodName string) *mockedMethod { 84 genericMock.Lock() 85 defer genericMock.Unlock() 86 if _, ok := genericMock.mockedMethods[methodName]; !ok { 87 genericMock.mockedMethods[methodName] = &mockedMethod{name: methodName} 88 } 89 return genericMock.mockedMethods[methodName] 90 } 91 92 func (genericMock *GenericMock) reset(methodName string, paramMatchers []ArgumentMatcher) { 93 genericMock.getOrCreateMockedMethod(methodName).reset(paramMatchers) 94 } 95 96 func (genericMock *GenericMock) Verify( 97 inOrderContext *InOrderContext, 98 invocationCountMatcher InvocationCountMatcher, 99 methodName string, 100 params []Param, 101 options ...interface{}, 102 ) []MethodInvocation { 103 var timeout time.Duration 104 if len(options) == 1 { 105 timeout = options[0].(time.Duration) 106 } 107 if genericMock.fail == nil && GlobalFailHandler == nil { 108 panic("No FailHandler set. Please use either RegisterMockFailHandler or RegisterMockTestingT or TODO to set a fail handler.") 109 } 110 fail := GlobalFailHandler 111 if genericMock.fail != nil { 112 fail = genericMock.fail 113 } 114 defer func() { globalArgMatchers = nil }() // We don't want a panic somewhere during verification screw our global argMatchers 115 116 if len(globalArgMatchers) != 0 { 117 verifyArgMatcherUse(globalArgMatchers, params) 118 } 119 startTime := time.Now() 120 // timeoutLoop: 121 for { 122 genericMock.Lock() 123 methodInvocations := genericMock.methodInvocations(methodName, params, globalArgMatchers) 124 genericMock.Unlock() 125 if inOrderContext != nil { 126 for _, methodInvocation := range methodInvocations { 127 if methodInvocation.orderingInvocationNumber <= inOrderContext.invocationCounter { 128 // TODO: should introduce the following, in case we decide support "inorder" and "eventually" 129 // if time.Since(startTime) < timeout { 130 // continue timeoutLoop 131 // } 132 fail(fmt.Sprintf("Expected function call %v(%v) before function call %v(%v)", 133 methodName, formatParams(params), inOrderContext.lastInvokedMethodName, formatParams(inOrderContext.lastInvokedMethodParams))) 134 } 135 inOrderContext.invocationCounter = methodInvocation.orderingInvocationNumber 136 inOrderContext.lastInvokedMethodName = methodName 137 inOrderContext.lastInvokedMethodParams = params 138 } 139 } 140 if !invocationCountMatcher.Matches(len(methodInvocations)) { 141 if time.Since(startTime) < timeout { 142 time.Sleep(10 * time.Millisecond) 143 continue 144 } 145 var paramsOrMatchers interface{} = formatParams(params) 146 if len(globalArgMatchers) != 0 { 147 paramsOrMatchers = formatMatchers(globalArgMatchers) 148 } 149 timeoutInfo := "" 150 if timeout > 0 { 151 timeoutInfo = fmt.Sprintf(" after timeout of %v", timeout) 152 } 153 fail(fmt.Sprintf( 154 "Mock invocation count for %v(%v) does not match expectation%v.\n\n\t%v\n\n\t%v", 155 methodName, paramsOrMatchers, timeoutInfo, invocationCountMatcher.FailureMessage(), formatInteractions(genericMock.allInteractions()))) 156 } 157 return methodInvocations 158 } 159 } 160 161 // TODO this doesn't need to be a method, can be a free function 162 func (genericMock *GenericMock) GetInvocationParams(methodInvocations []MethodInvocation) [][]Param { 163 if len(methodInvocations) == 0 { 164 return nil 165 } 166 result := make([][]Param, len(methodInvocations[len(methodInvocations)-1].params)) 167 for i, invocation := range methodInvocations { 168 for u, param := range invocation.params { 169 if result[u] == nil { 170 result[u] = make([]Param, len(methodInvocations)) 171 } 172 result[u][i] = param 173 } 174 } 175 return result 176 } 177 178 func (genericMock *GenericMock) methodInvocations(methodName string, params []Param, matchers []ArgumentMatcher) []MethodInvocation { 179 var invocations []MethodInvocation 180 if method, exists := genericMock.mockedMethods[methodName]; exists { 181 method.Lock() 182 for _, invocation := range method.invocations { 183 if len(matchers) != 0 { 184 if Matchers(matchers).Matches(invocation.params) { 185 invocations = append(invocations, invocation) 186 } 187 } else { 188 if reflect.DeepEqual(params, invocation.params) || 189 (len(params) == 0 && len(invocation.params) == 0) { 190 invocations = append(invocations, invocation) 191 } 192 } 193 } 194 method.Unlock() 195 } 196 return invocations 197 } 198 199 func formatInteractions(interactions map[string][]MethodInvocation) string { 200 if len(interactions) == 0 { 201 return "There were no other interactions with this mock" 202 } 203 result := "Actual interactions with this mock were:\n" 204 for _, methodName := range sortedMethodNames(interactions) { 205 result += formatInvocations(methodName, interactions[methodName]) 206 } 207 return result 208 } 209 210 func formatInvocations(methodName string, invocations []MethodInvocation) (result string) { 211 for _, invocation := range invocations { 212 result += "\t" + methodName + "(" + formatParams(invocation.params) + ")\n" 213 } 214 return 215 } 216 217 func formatParams(params []Param) (result string) { 218 for i, param := range params { 219 if i > 0 { 220 result += ", " 221 } 222 result += fmt.Sprintf("%#v", param) 223 } 224 return 225 } 226 227 func formatMatchers(matchers []ArgumentMatcher) (result string) { 228 for i, matcher := range matchers { 229 if i > 0 { 230 result += ", " 231 } 232 result += fmt.Sprintf("%v", matcher) 233 } 234 return 235 } 236 237 func sortedMethodNames(interactions map[string][]MethodInvocation) []string { 238 methodNames := make([]string, len(interactions)) 239 i := 0 240 for key := range interactions { 241 methodNames[i] = key 242 i++ 243 } 244 sort.Strings(methodNames) 245 return methodNames 246 } 247 248 func (genericMock *GenericMock) allInteractions() map[string][]MethodInvocation { 249 interactions := make(map[string][]MethodInvocation) 250 for methodName := range genericMock.mockedMethods { 251 for _, invocation := range genericMock.mockedMethods[methodName].invocations { 252 interactions[methodName] = append(interactions[methodName], invocation) 253 } 254 } 255 return interactions 256 } 257 258 type mockedMethod struct { 259 sync.Mutex 260 name string 261 invocations []MethodInvocation 262 stubbings Stubbings 263 } 264 265 func (method *mockedMethod) Invoke(params []Param) ReturnValues { 266 method.Lock() 267 method.invocations = append(method.invocations, MethodInvocation{params, globalInvocationCounter.nextNumber()}) 268 method.Unlock() 269 stubbing := method.stubbings.find(params) 270 if stubbing == nil { 271 return ReturnValues{} 272 } 273 return stubbing.Invoke(params) 274 } 275 276 func (method *mockedMethod) stub(paramMatchers Matchers, callback func([]Param) ReturnValues) { 277 stubbing := method.stubbings.findByMatchers(paramMatchers) 278 if stubbing == nil { 279 stubbing = &Stubbing{paramMatchers: paramMatchers} 280 method.stubbings = append(method.stubbings, stubbing) 281 } 282 stubbing.callbackSequence = append(stubbing.callbackSequence, callback) 283 } 284 285 func (method *mockedMethod) removeLastInvocation() { 286 method.invocations = method.invocations[:len(method.invocations)-1] 287 } 288 289 func (method *mockedMethod) reset(paramMatchers Matchers) { 290 method.stubbings.removeByMatchers(paramMatchers) 291 } 292 293 type Counter struct { 294 count int 295 sync.Mutex 296 } 297 298 func (counter *Counter) nextNumber() (nextNumber int) { 299 counter.Lock() 300 defer counter.Unlock() 301 302 nextNumber = counter.count 303 counter.count++ 304 return 305 } 306 307 var globalInvocationCounter = Counter{count: 1} 308 309 type MethodInvocation struct { 310 params []Param 311 orderingInvocationNumber int 312 } 313 314 type Stubbings []*Stubbing 315 316 func (stubbings Stubbings) find(params []Param) *Stubbing { 317 for i := len(stubbings) - 1; i >= 0; i-- { 318 if stubbings[i].paramMatchers.Matches(params) { 319 return stubbings[i] 320 } 321 } 322 return nil 323 } 324 325 func (stubbings Stubbings) findByMatchers(paramMatchers Matchers) *Stubbing { 326 for _, stubbing := range stubbings { 327 if matchersEqual(stubbing.paramMatchers, paramMatchers) { 328 return stubbing 329 } 330 } 331 return nil 332 } 333 334 func (stubbings *Stubbings) removeByMatchers(paramMatchers Matchers) { 335 for i, stubbing := range *stubbings { 336 if matchersEqual(stubbing.paramMatchers, paramMatchers) { 337 *stubbings = append((*stubbings)[:i], (*stubbings)[i+1:]...) 338 } 339 } 340 } 341 342 func matchersEqual(a, b Matchers) bool { 343 if len(a) != len(b) { 344 return false 345 } 346 for i := range a { 347 if !reflect.DeepEqual(a[i], b[i]) { 348 return false 349 } 350 } 351 return true 352 } 353 354 type Stubbing struct { 355 paramMatchers Matchers 356 callbackSequence []func([]Param) ReturnValues 357 sequencePointer int 358 } 359 360 func (stubbing *Stubbing) Invoke(params []Param) ReturnValues { 361 defer func() { 362 if stubbing.sequencePointer < len(stubbing.callbackSequence)-1 { 363 stubbing.sequencePointer++ 364 } 365 }() 366 return stubbing.callbackSequence[stubbing.sequencePointer](params) 367 } 368 369 type Matchers []ArgumentMatcher 370 371 func (matchers Matchers) Matches(params []Param) bool { 372 if len(matchers) != len(params) { // Technically, this is not an error. Variadic arguments can cause this 373 return false 374 } 375 376 for i := range params { 377 if !matchers[i].Matches(params[i]) { 378 return false 379 } 380 } 381 return true 382 } 383 384 func (matchers *Matchers) append(matcher ArgumentMatcher) { 385 *matchers = append(*matchers, matcher) 386 } 387 388 type ongoingStubbing struct { 389 genericMock *GenericMock 390 MethodName string 391 ParamMatchers []ArgumentMatcher 392 returnTypes []reflect.Type 393 } 394 395 func When(invocation ...interface{}) *ongoingStubbing { 396 callIfIsFunc(invocation) 397 verify.Argument(lastInvocation != nil, 398 "When() requires an argument which has to be 'a method call on a mock'.") 399 defer func() { 400 lastInvocationMutex.Lock() 401 lastInvocation = nil 402 lastInvocationMutex.Unlock() 403 404 globalArgMatchers = nil 405 }() 406 lastInvocation.genericMock.mockedMethods[lastInvocation.MethodName].removeLastInvocation() 407 408 paramMatchers := paramMatchersFromArgMatchersOrParams(globalArgMatchers, lastInvocation.Params) 409 lastInvocation.genericMock.reset(lastInvocation.MethodName, paramMatchers) 410 return &ongoingStubbing{ 411 genericMock: lastInvocation.genericMock, 412 MethodName: lastInvocation.MethodName, 413 ParamMatchers: paramMatchers, 414 returnTypes: lastInvocation.ReturnTypes, 415 } 416 } 417 418 func callIfIsFunc(invocation []interface{}) { 419 if len(invocation) == 1 { 420 actualType := actualTypeOf(invocation[0]) 421 if actualType != nil && actualType.Kind() == reflect.Func && !reflect.ValueOf(invocation[0]).IsNil() { 422 if !(actualType.NumIn() == 0 && actualType.NumOut() == 0) { 423 panic("When using 'When' with function that does not return a value, " + 424 "it expects a function with no arguments and no return value.") 425 } 426 reflect.ValueOf(invocation[0]).Call([]reflect.Value{}) 427 } 428 } 429 } 430 431 // Deals with nils without panicking 432 func actualTypeOf(iface interface{}) reflect.Type { 433 defer func() { recover() }() 434 return reflect.TypeOf(iface) 435 } 436 437 func paramMatchersFromArgMatchersOrParams(argMatchers []ArgumentMatcher, params []Param) []ArgumentMatcher { 438 if len(argMatchers) != 0 { 439 verifyArgMatcherUse(argMatchers, params) 440 return argMatchers 441 } 442 return transformParamsIntoEqMatchers(params) 443 } 444 445 func verifyArgMatcherUse(argMatchers []ArgumentMatcher, params []Param) { 446 verify.Argument(len(argMatchers) == len(params), 447 "Invalid use of matchers!\n\n %v matchers expected, %v recorded.\n\n"+ 448 "This error may occur if matchers are combined with raw values:\n"+ 449 " //incorrect:\n"+ 450 " someFunc(AnyInt(), \"raw String\")\n"+ 451 "When using matchers, all arguments have to be provided by matchers.\n"+ 452 "For example:\n"+ 453 " //correct:\n"+ 454 " someFunc(AnyInt(), EqString(\"String by matcher\"))", 455 len(params), len(argMatchers), 456 ) 457 } 458 459 func transformParamsIntoEqMatchers(params []Param) []ArgumentMatcher { 460 paramMatchers := make([]ArgumentMatcher, len(params)) 461 for i, param := range params { 462 paramMatchers[i] = &EqMatcher{Value: param} 463 } 464 return paramMatchers 465 } 466 467 var ( 468 genericMocksMutex sync.Mutex 469 genericMocks = make(map[Mock]*GenericMock) 470 ) 471 472 func GetGenericMockFrom(mock Mock) *GenericMock { 473 genericMocksMutex.Lock() 474 defer genericMocksMutex.Unlock() 475 if genericMocks[mock] == nil { 476 genericMocks[mock] = &GenericMock{ 477 mockedMethods: make(map[string]*mockedMethod), 478 fail: mock.FailHandler(), 479 } 480 } 481 return genericMocks[mock] 482 } 483 484 func (stubbing *ongoingStubbing) ThenReturn(values ...ReturnValue) *ongoingStubbing { 485 checkAssignabilityOf(values, stubbing.returnTypes) 486 stubbing.genericMock.stub(stubbing.MethodName, stubbing.ParamMatchers, values) 487 return stubbing 488 } 489 490 func checkAssignabilityOf(stubbedReturnValues []ReturnValue, expectedReturnTypes []reflect.Type) { 491 verify.Argument(len(stubbedReturnValues) == len(expectedReturnTypes), 492 "Different number of return values") 493 for i := range stubbedReturnValues { 494 if stubbedReturnValues[i] == nil { 495 switch expectedReturnTypes[i].Kind() { 496 case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, 497 reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Float32, 498 reflect.Float64, reflect.Complex64, reflect.Complex128, reflect.Array, reflect.String, 499 reflect.Struct: 500 panic("Return value 'nil' not assignable to return type " + expectedReturnTypes[i].Kind().String()) 501 } 502 } else { 503 verify.Argument(reflect.TypeOf(stubbedReturnValues[i]).AssignableTo(expectedReturnTypes[i]), 504 "Return value of type %T not assignable to return type %v", stubbedReturnValues[i], expectedReturnTypes[i]) 505 } 506 } 507 } 508 509 func (stubbing *ongoingStubbing) ThenPanic(v interface{}) *ongoingStubbing { 510 stubbing.genericMock.stubWithCallback( 511 stubbing.MethodName, 512 stubbing.ParamMatchers, 513 func([]Param) ReturnValues { panic(v) }) 514 return stubbing 515 } 516 517 func (stubbing *ongoingStubbing) Then(callback func([]Param) ReturnValues) *ongoingStubbing { 518 stubbing.genericMock.stubWithCallback( 519 stubbing.MethodName, 520 stubbing.ParamMatchers, 521 callback) 522 return stubbing 523 } 524 525 type InOrderContext struct { 526 invocationCounter int 527 lastInvokedMethodName string 528 lastInvokedMethodParams []Param 529 } 530 531 // ArgumentMatcher can be used to match arguments. 532 type ArgumentMatcher interface { 533 Matches(param Param) bool 534 fmt.Stringer 535 } 536 537 // InvocationCountMatcher can be used to match invocation counts. It is guaranteed that 538 // FailureMessage will always be called after Matches so an implementation can save state. 539 type InvocationCountMatcher interface { 540 Matches(param Param) bool 541 FailureMessage() string 542 } 543 544 // Matcher can be used to match arguments as well as invocation counts. 545 // Note that support for overlapping embedded interfaces was added in Go 1.14, which is why 546 // ArgumentMatcher and InvocationCountMatcher are not embedded here. 547 type Matcher interface { 548 Matches(param Param) bool 549 FailureMessage() string 550 fmt.Stringer 551 } 552 553 func DumpInvocationsFor(mock Mock) { 554 fmt.Print(SDumpInvocationsFor(mock)) 555 } 556 557 func SDumpInvocationsFor(mock Mock) string { 558 result := &bytes.Buffer{} 559 for _, mockedMethod := range GetGenericMockFrom(mock).mockedMethods { 560 for _, invocation := range mockedMethod.invocations { 561 fmt.Fprintf(result, "Method invocation: %v (\n", mockedMethod.name) 562 for _, param := range invocation.params { 563 fmt.Fprint(result, format.Object(param, 1), ",\n") 564 } 565 fmt.Fprintln(result, ")") 566 } 567 } 568 return result.String() 569 } 570 571 // InterceptMockFailures runs a given callback and returns an array of 572 // failure messages generated by any Pegomock verifications within the callback. 573 // 574 // This is accomplished by temporarily replacing the *global* fail handler 575 // with a fail handler that simply annotates failures. The original fail handler 576 // is reset when InterceptMockFailures returns. 577 func InterceptMockFailures(f func()) []string { 578 originalHandler := GlobalFailHandler 579 failures := []string{} 580 RegisterMockFailHandler(func(message string, callerSkip ...int) { 581 failures = append(failures, message) 582 }) 583 f() 584 RegisterMockFailHandler(originalHandler) 585 return failures 586 }