github.com/golang/mock@v1.6.0/gomock/call.go (about) 1 // Copyright 2010 Google Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package gomock 16 17 import ( 18 "fmt" 19 "reflect" 20 "strconv" 21 "strings" 22 ) 23 24 // Call represents an expected call to a mock. 25 type Call struct { 26 t TestHelper // for triggering test failures on invalid call setup 27 28 receiver interface{} // the receiver of the method call 29 method string // the name of the method 30 methodType reflect.Type // the type of the method 31 args []Matcher // the args 32 origin string // file and line number of call setup 33 34 preReqs []*Call // prerequisite calls 35 36 // Expectations 37 minCalls, maxCalls int 38 39 numCalls int // actual number made 40 41 // actions are called when this Call is called. Each action gets the args and 42 // can set the return values by returning a non-nil slice. Actions run in the 43 // order they are created. 44 actions []func([]interface{}) []interface{} 45 } 46 47 // newCall creates a *Call. It requires the method type in order to support 48 // unexported methods. 49 func newCall(t TestHelper, receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call { 50 t.Helper() 51 52 // TODO: check arity, types. 53 mArgs := make([]Matcher, len(args)) 54 for i, arg := range args { 55 if m, ok := arg.(Matcher); ok { 56 mArgs[i] = m 57 } else if arg == nil { 58 // Handle nil specially so that passing a nil interface value 59 // will match the typed nils of concrete args. 60 mArgs[i] = Nil() 61 } else { 62 mArgs[i] = Eq(arg) 63 } 64 } 65 66 // callerInfo's skip should be updated if the number of calls between the user's test 67 // and this line changes, i.e. this code is wrapped in another anonymous function. 68 // 0 is us, 1 is RecordCallWithMethodType(), 2 is the generated recorder, and 3 is the user's test. 69 origin := callerInfo(3) 70 actions := []func([]interface{}) []interface{}{func([]interface{}) []interface{} { 71 // Synthesize the zero value for each of the return args' types. 72 rets := make([]interface{}, methodType.NumOut()) 73 for i := 0; i < methodType.NumOut(); i++ { 74 rets[i] = reflect.Zero(methodType.Out(i)).Interface() 75 } 76 return rets 77 }} 78 return &Call{t: t, receiver: receiver, method: method, methodType: methodType, 79 args: mArgs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions} 80 } 81 82 // AnyTimes allows the expectation to be called 0 or more times 83 func (c *Call) AnyTimes() *Call { 84 c.minCalls, c.maxCalls = 0, 1e8 // close enough to infinity 85 return c 86 } 87 88 // MinTimes requires the call to occur at least n times. If AnyTimes or MaxTimes have not been called or if MaxTimes 89 // was previously called with 1, MinTimes also sets the maximum number of calls to infinity. 90 func (c *Call) MinTimes(n int) *Call { 91 c.minCalls = n 92 if c.maxCalls == 1 { 93 c.maxCalls = 1e8 94 } 95 return c 96 } 97 98 // MaxTimes limits the number of calls to n times. If AnyTimes or MinTimes have not been called or if MinTimes was 99 // previously called with 1, MaxTimes also sets the minimum number of calls to 0. 100 func (c *Call) MaxTimes(n int) *Call { 101 c.maxCalls = n 102 if c.minCalls == 1 { 103 c.minCalls = 0 104 } 105 return c 106 } 107 108 // DoAndReturn declares the action to run when the call is matched. 109 // The return values from this function are returned by the mocked function. 110 // It takes an interface{} argument to support n-arity functions. 111 func (c *Call) DoAndReturn(f interface{}) *Call { 112 // TODO: Check arity and types here, rather than dying badly elsewhere. 113 v := reflect.ValueOf(f) 114 115 c.addAction(func(args []interface{}) []interface{} { 116 c.t.Helper() 117 vArgs := make([]reflect.Value, len(args)) 118 ft := v.Type() 119 if c.methodType.NumIn() != ft.NumIn() { 120 c.t.Fatalf("wrong number of arguments in DoAndReturn func for %T.%v: got %d, want %d [%s]", 121 c.receiver, c.method, ft.NumIn(), c.methodType.NumIn(), c.origin) 122 return nil 123 } 124 for i := 0; i < len(args); i++ { 125 if args[i] != nil { 126 vArgs[i] = reflect.ValueOf(args[i]) 127 } else { 128 // Use the zero value for the arg. 129 vArgs[i] = reflect.Zero(ft.In(i)) 130 } 131 } 132 vRets := v.Call(vArgs) 133 rets := make([]interface{}, len(vRets)) 134 for i, ret := range vRets { 135 rets[i] = ret.Interface() 136 } 137 return rets 138 }) 139 return c 140 } 141 142 // Do declares the action to run when the call is matched. The function's 143 // return values are ignored to retain backward compatibility. To use the 144 // return values call DoAndReturn. 145 // It takes an interface{} argument to support n-arity functions. 146 func (c *Call) Do(f interface{}) *Call { 147 // TODO: Check arity and types here, rather than dying badly elsewhere. 148 v := reflect.ValueOf(f) 149 150 c.addAction(func(args []interface{}) []interface{} { 151 c.t.Helper() 152 if c.methodType.NumIn() != v.Type().NumIn() { 153 c.t.Fatalf("wrong number of arguments in Do func for %T.%v: got %d, want %d [%s]", 154 c.receiver, c.method, v.Type().NumIn(), c.methodType.NumIn(), c.origin) 155 return nil 156 } 157 vArgs := make([]reflect.Value, len(args)) 158 ft := v.Type() 159 for i := 0; i < len(args); i++ { 160 if args[i] != nil { 161 vArgs[i] = reflect.ValueOf(args[i]) 162 } else { 163 // Use the zero value for the arg. 164 vArgs[i] = reflect.Zero(ft.In(i)) 165 } 166 } 167 v.Call(vArgs) 168 return nil 169 }) 170 return c 171 } 172 173 // Return declares the values to be returned by the mocked function call. 174 func (c *Call) Return(rets ...interface{}) *Call { 175 c.t.Helper() 176 177 mt := c.methodType 178 if len(rets) != mt.NumOut() { 179 c.t.Fatalf("wrong number of arguments to Return for %T.%v: got %d, want %d [%s]", 180 c.receiver, c.method, len(rets), mt.NumOut(), c.origin) 181 } 182 for i, ret := range rets { 183 if got, want := reflect.TypeOf(ret), mt.Out(i); got == want { 184 // Identical types; nothing to do. 185 } else if got == nil { 186 // Nil needs special handling. 187 switch want.Kind() { 188 case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: 189 // ok 190 default: 191 c.t.Fatalf("argument %d to Return for %T.%v is nil, but %v is not nillable [%s]", 192 i, c.receiver, c.method, want, c.origin) 193 } 194 } else if got.AssignableTo(want) { 195 // Assignable type relation. Make the assignment now so that the generated code 196 // can return the values with a type assertion. 197 v := reflect.New(want).Elem() 198 v.Set(reflect.ValueOf(ret)) 199 rets[i] = v.Interface() 200 } else { 201 c.t.Fatalf("wrong type of argument %d to Return for %T.%v: %v is not assignable to %v [%s]", 202 i, c.receiver, c.method, got, want, c.origin) 203 } 204 } 205 206 c.addAction(func([]interface{}) []interface{} { 207 return rets 208 }) 209 210 return c 211 } 212 213 // Times declares the exact number of times a function call is expected to be executed. 214 func (c *Call) Times(n int) *Call { 215 c.minCalls, c.maxCalls = n, n 216 return c 217 } 218 219 // SetArg declares an action that will set the nth argument's value, 220 // indirected through a pointer. Or, in the case of a slice, SetArg 221 // will copy value's elements into the nth argument. 222 func (c *Call) SetArg(n int, value interface{}) *Call { 223 c.t.Helper() 224 225 mt := c.methodType 226 // TODO: This will break on variadic methods. 227 // We will need to check those at invocation time. 228 if n < 0 || n >= mt.NumIn() { 229 c.t.Fatalf("SetArg(%d, ...) called for a method with %d args [%s]", 230 n, mt.NumIn(), c.origin) 231 } 232 // Permit setting argument through an interface. 233 // In the interface case, we don't (nay, can't) check the type here. 234 at := mt.In(n) 235 switch at.Kind() { 236 case reflect.Ptr: 237 dt := at.Elem() 238 if vt := reflect.TypeOf(value); !vt.AssignableTo(dt) { 239 c.t.Fatalf("SetArg(%d, ...) argument is a %v, not assignable to %v [%s]", 240 n, vt, dt, c.origin) 241 } 242 case reflect.Interface: 243 // nothing to do 244 case reflect.Slice: 245 // nothing to do 246 default: 247 c.t.Fatalf("SetArg(%d, ...) referring to argument of non-pointer non-interface non-slice type %v [%s]", 248 n, at, c.origin) 249 } 250 251 c.addAction(func(args []interface{}) []interface{} { 252 v := reflect.ValueOf(value) 253 switch reflect.TypeOf(args[n]).Kind() { 254 case reflect.Slice: 255 setSlice(args[n], v) 256 default: 257 reflect.ValueOf(args[n]).Elem().Set(v) 258 } 259 return nil 260 }) 261 return c 262 } 263 264 // isPreReq returns true if other is a direct or indirect prerequisite to c. 265 func (c *Call) isPreReq(other *Call) bool { 266 for _, preReq := range c.preReqs { 267 if other == preReq || preReq.isPreReq(other) { 268 return true 269 } 270 } 271 return false 272 } 273 274 // After declares that the call may only match after preReq has been exhausted. 275 func (c *Call) After(preReq *Call) *Call { 276 c.t.Helper() 277 278 if c == preReq { 279 c.t.Fatalf("A call isn't allowed to be its own prerequisite") 280 } 281 if preReq.isPreReq(c) { 282 c.t.Fatalf("Loop in call order: %v is a prerequisite to %v (possibly indirectly).", c, preReq) 283 } 284 285 c.preReqs = append(c.preReqs, preReq) 286 return c 287 } 288 289 // Returns true if the minimum number of calls have been made. 290 func (c *Call) satisfied() bool { 291 return c.numCalls >= c.minCalls 292 } 293 294 // Returns true if the maximum number of calls have been made. 295 func (c *Call) exhausted() bool { 296 return c.numCalls >= c.maxCalls 297 } 298 299 func (c *Call) String() string { 300 args := make([]string, len(c.args)) 301 for i, arg := range c.args { 302 args[i] = arg.String() 303 } 304 arguments := strings.Join(args, ", ") 305 return fmt.Sprintf("%T.%v(%s) %s", c.receiver, c.method, arguments, c.origin) 306 } 307 308 // Tests if the given call matches the expected call. 309 // If yes, returns nil. If no, returns error with message explaining why it does not match. 310 func (c *Call) matches(args []interface{}) error { 311 if !c.methodType.IsVariadic() { 312 if len(args) != len(c.args) { 313 return fmt.Errorf("expected call at %s has the wrong number of arguments. Got: %d, want: %d", 314 c.origin, len(args), len(c.args)) 315 } 316 317 for i, m := range c.args { 318 if !m.Matches(args[i]) { 319 return fmt.Errorf( 320 "expected call at %s doesn't match the argument at index %d.\nGot: %v\nWant: %v", 321 c.origin, i, formatGottenArg(m, args[i]), m, 322 ) 323 } 324 } 325 } else { 326 if len(c.args) < c.methodType.NumIn()-1 { 327 return fmt.Errorf("expected call at %s has the wrong number of matchers. Got: %d, want: %d", 328 c.origin, len(c.args), c.methodType.NumIn()-1) 329 } 330 if len(c.args) != c.methodType.NumIn() && len(args) != len(c.args) { 331 return fmt.Errorf("expected call at %s has the wrong number of arguments. Got: %d, want: %d", 332 c.origin, len(args), len(c.args)) 333 } 334 if len(args) < len(c.args)-1 { 335 return fmt.Errorf("expected call at %s has the wrong number of arguments. Got: %d, want: greater than or equal to %d", 336 c.origin, len(args), len(c.args)-1) 337 } 338 339 for i, m := range c.args { 340 if i < c.methodType.NumIn()-1 { 341 // Non-variadic args 342 if !m.Matches(args[i]) { 343 return fmt.Errorf("expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v", 344 c.origin, strconv.Itoa(i), formatGottenArg(m, args[i]), m) 345 } 346 continue 347 } 348 // The last arg has a possibility of a variadic argument, so let it branch 349 350 // sample: Foo(a int, b int, c ...int) 351 if i < len(c.args) && i < len(args) { 352 if m.Matches(args[i]) { 353 // Got Foo(a, b, c) want Foo(matcherA, matcherB, gomock.Any()) 354 // Got Foo(a, b, c) want Foo(matcherA, matcherB, someSliceMatcher) 355 // Got Foo(a, b, c) want Foo(matcherA, matcherB, matcherC) 356 // Got Foo(a, b) want Foo(matcherA, matcherB) 357 // Got Foo(a, b, c, d) want Foo(matcherA, matcherB, matcherC, matcherD) 358 continue 359 } 360 } 361 362 // The number of actual args don't match the number of matchers, 363 // or the last matcher is a slice and the last arg is not. 364 // If this function still matches it is because the last matcher 365 // matches all the remaining arguments or the lack of any. 366 // Convert the remaining arguments, if any, into a slice of the 367 // expected type. 368 vArgsType := c.methodType.In(c.methodType.NumIn() - 1) 369 vArgs := reflect.MakeSlice(vArgsType, 0, len(args)-i) 370 for _, arg := range args[i:] { 371 vArgs = reflect.Append(vArgs, reflect.ValueOf(arg)) 372 } 373 if m.Matches(vArgs.Interface()) { 374 // Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, gomock.Any()) 375 // Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, someSliceMatcher) 376 // Got Foo(a, b) want Foo(matcherA, matcherB, gomock.Any()) 377 // Got Foo(a, b) want Foo(matcherA, matcherB, someEmptySliceMatcher) 378 break 379 } 380 // Wrong number of matchers or not match. Fail. 381 // Got Foo(a, b) want Foo(matcherA, matcherB, matcherC, matcherD) 382 // Got Foo(a, b, c) want Foo(matcherA, matcherB, matcherC, matcherD) 383 // Got Foo(a, b, c, d) want Foo(matcherA, matcherB, matcherC, matcherD, matcherE) 384 // Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, matcherC, matcherD) 385 // Got Foo(a, b, c) want Foo(matcherA, matcherB) 386 387 return fmt.Errorf("expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v", 388 c.origin, strconv.Itoa(i), formatGottenArg(m, args[i:]), c.args[i]) 389 } 390 } 391 392 // Check that all prerequisite calls have been satisfied. 393 for _, preReqCall := range c.preReqs { 394 if !preReqCall.satisfied() { 395 return fmt.Errorf("expected call at %s doesn't have a prerequisite call satisfied:\n%v\nshould be called before:\n%v", 396 c.origin, preReqCall, c) 397 } 398 } 399 400 // Check that the call is not exhausted. 401 if c.exhausted() { 402 return fmt.Errorf("expected call at %s has already been called the max number of times", c.origin) 403 } 404 405 return nil 406 } 407 408 // dropPrereqs tells the expected Call to not re-check prerequisite calls any 409 // longer, and to return its current set. 410 func (c *Call) dropPrereqs() (preReqs []*Call) { 411 preReqs = c.preReqs 412 c.preReqs = nil 413 return 414 } 415 416 func (c *Call) call() []func([]interface{}) []interface{} { 417 c.numCalls++ 418 return c.actions 419 } 420 421 // InOrder declares that the given calls should occur in order. 422 func InOrder(calls ...*Call) { 423 for i := 1; i < len(calls); i++ { 424 calls[i].After(calls[i-1]) 425 } 426 } 427 428 func setSlice(arg interface{}, v reflect.Value) { 429 va := reflect.ValueOf(arg) 430 for i := 0; i < v.Len(); i++ { 431 va.Index(i).Set(v.Index(i)) 432 } 433 } 434 435 func (c *Call) addAction(action func([]interface{}) []interface{}) { 436 c.actions = append(c.actions, action) 437 } 438 439 func formatGottenArg(m Matcher, arg interface{}) string { 440 got := fmt.Sprintf("%v (%T)", arg, arg) 441 if gs, ok := m.(GotFormatter); ok { 442 got = gs.Got(arg) 443 } 444 return got 445 }