github.com/bytedance/mockey@v1.2.10/mock.go (about) 1 /* 2 * Copyright 2022 ByteDance Inc. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package mockey 18 19 import ( 20 "reflect" 21 "sync" 22 "sync/atomic" 23 24 "github.com/bytedance/mockey/internal/monkey" 25 "github.com/bytedance/mockey/internal/tool" 26 ) 27 28 type FilterGoroutineType int64 29 30 const ( 31 Disable FilterGoroutineType = 0 32 Include FilterGoroutineType = 1 33 Exclude FilterGoroutineType = 2 34 ) 35 36 type Mocker struct { 37 target reflect.Value // mock target value 38 hook reflect.Value // mock hook 39 proxy interface{} // proxy function to origin 40 times int64 41 mockTimes int64 42 patch *monkey.Patch 43 lock sync.Mutex 44 isPatched bool 45 builder *MockBuilder 46 47 outerCaller tool.CallerInfo 48 } 49 50 type MockBuilder struct { 51 target interface{} // mock target 52 proxyCaller interface{} // origin function caller hook 53 conditions []*mockCondition // mock conditions 54 filterGoroutine FilterGoroutineType 55 gId int64 56 unsafe bool 57 generic bool 58 } 59 60 // Mock mocks target function 61 // 62 // If target is a generic method or method of generic types, you need add a genericOpt, like this: 63 // 64 // func f[int, float64](x int, y T1) T2 65 // Mock(f[int, float64], OptGeneric) 66 func Mock(target interface{}, opt ...optionFn) *MockBuilder { 67 tool.AssertFunc(target) 68 69 option := resolveOpt(opt...) 70 71 builder := &MockBuilder{ 72 target: target, 73 unsafe: option.unsafe, 74 generic: option.generic, 75 } 76 builder.resetCondition() 77 return builder 78 } 79 80 // MockUnsafe has the full ability of the Mock function and removes some security restrictions. This is an alternative 81 // when the Mock function fails. It may cause some unknown problems, so we recommend using Mock under normal conditions. 82 func MockUnsafe(target interface{}) *MockBuilder { 83 return Mock(target, OptUnsafe) 84 } 85 86 func (builder *MockBuilder) hookType() reflect.Type { 87 targetType := reflect.TypeOf(builder.target) 88 if builder.generic { 89 targetIn := []reflect.Type{genericInfoType} 90 for i := 0; i < targetType.NumIn(); i++ { 91 targetIn = append(targetIn, targetType.In(i)) 92 } 93 targetOut := []reflect.Type{} 94 for i := 0; i < targetType.NumOut(); i++ { 95 targetOut = append(targetOut, targetType.Out(i)) 96 } 97 return reflect.FuncOf(targetIn, targetOut, targetType.IsVariadic()) 98 } 99 return targetType 100 } 101 102 func (builder *MockBuilder) resetCondition() *MockBuilder { 103 builder.conditions = []*mockCondition{builder.newCondition()} // at least 1 condition is needed 104 return builder 105 } 106 107 // Origin add an origin hook which can be used to call un-mocked origin function 108 // 109 // For example: 110 // 111 // origin := Fun // only need the same type 112 // mock := func(p string) string { 113 // return origin(p + "mocked") 114 // } 115 // mock2 := Mock(Fun).To(mock).Origin(&origin).Build() 116 // 117 // Origin only works when call origin hook directly, target will still be mocked in recursive call 118 func (builder *MockBuilder) Origin(funcPtr interface{}) *MockBuilder { 119 tool.Assert(builder.proxyCaller == nil, "re-set builder origin") 120 return builder.origin(funcPtr) 121 } 122 123 func (builder *MockBuilder) origin(funcPtr interface{}) *MockBuilder { 124 tool.AssertPtr(funcPtr) 125 builder.proxyCaller = funcPtr 126 return builder 127 } 128 129 func (builder *MockBuilder) lastCondition() *mockCondition { 130 cond := builder.conditions[len(builder.conditions)-1] 131 if cond.Complete() { 132 cond = builder.newCondition() 133 builder.conditions = append(builder.conditions, cond) 134 } 135 return cond 136 } 137 138 func (builder *MockBuilder) newCondition() *mockCondition { 139 return &mockCondition{builder: builder} 140 } 141 142 // When declares the condition hook that's called to determine whether the mock should be executed. 143 // 144 // The condition hook function must have the same parameters as the target function. 145 // 146 // The following example would execute the mock when input int is negative 147 // 148 // func Fun(input int) string { 149 // return strconv.Itoa(input) 150 // } 151 // Mock(Fun).When(func(input int) bool { return input < 0 }).Return("0").Build() 152 // 153 // Note that if the target function is a struct method, you may optionally include 154 // the receiver as the first argument of the condition hook function. For example, 155 // 156 // type Foo struct { 157 // Age int 158 // } 159 // func (f *Foo) GetAge(younger int) string { 160 // return strconv.Itoa(f.Age - younger) 161 // } 162 // Mock((*Foo).GetAge).When(func(f *Foo, younger int) bool { return younger < 0 }).Return("0").Build() 163 func (builder *MockBuilder) When(when interface{}) *MockBuilder { 164 builder.lastCondition().SetWhen(when) 165 return builder 166 } 167 168 // To declares the hook function that's called to replace the target function. 169 // 170 // The hook function must have the same signature as the target function. 171 // 172 // The following example would make Fun always return true 173 // 174 // func Fun(input string) bool { 175 // return input == "fun" 176 // } 177 // 178 // Mock(Fun).To(func(_ string) bool {return true}).Build() 179 // 180 // Note that if the target function is a struct method, you may optionally include 181 // the receiver as the first argument of the hook function. For example, 182 // 183 // type Foo struct { 184 // Name string 185 // } 186 // func (f *Foo) Bar(other string) bool { 187 // return other == f.Name 188 // } 189 // Mock((*Foo).Bar).To(func(f *Foo, other string) bool {return true}).Build() 190 func (builder *MockBuilder) To(hook interface{}) *MockBuilder { 191 builder.lastCondition().SetTo(hook) 192 return builder 193 } 194 195 func (builder *MockBuilder) Return(results ...interface{}) *MockBuilder { 196 builder.lastCondition().SetReturn(results...) 197 return builder 198 } 199 200 func (builder *MockBuilder) IncludeCurrentGoRoutine() *MockBuilder { 201 return builder.FilterGoRoutine(Include, tool.GetGoroutineID()) 202 } 203 204 func (builder *MockBuilder) ExcludeCurrentGoRoutine() *MockBuilder { 205 return builder.FilterGoRoutine(Exclude, tool.GetGoroutineID()) 206 } 207 208 func (builder *MockBuilder) FilterGoRoutine(filter FilterGoroutineType, gId int64) *MockBuilder { 209 builder.filterGoroutine = filter 210 builder.gId = gId 211 return builder 212 } 213 214 func (builder *MockBuilder) Build() *Mocker { 215 mocker := Mocker{target: reflect.ValueOf(builder.target), builder: builder} 216 mocker.buildHook() 217 mocker.Patch() 218 return &mocker 219 } 220 221 func (mocker *Mocker) missReceiver(target reflect.Type, hook interface{}) bool { 222 hType := reflect.TypeOf(hook) 223 tool.Assert(hType.Kind() == reflect.Func, "Param(%v) a is not a func", hType.Kind()) 224 tool.Assert(target.IsVariadic() == hType.IsVariadic(), "target:%v, hook:%v args not match", target, hook) 225 // has receiver 226 if tool.CheckFuncArgs(target, hType, 0, 0) { 227 return false 228 } 229 if tool.CheckFuncArgs(target, hType, 1, 0) { 230 return true 231 } 232 tool.Assert(false, "target:%v, hook:%v args not match", target, hook) 233 return false 234 } 235 236 func (mocker *Mocker) buildHook() { 237 proxySetter := mocker.buildProxy() 238 239 originExec := func(args []reflect.Value) []reflect.Value { 240 return tool.ReflectCall(reflect.ValueOf(mocker.proxy).Elem(), args) 241 } 242 243 match := []func(args []reflect.Value) bool{} 244 exec := []func(args []reflect.Value) []reflect.Value{} 245 246 for i := range mocker.builder.conditions { 247 condition := mocker.builder.conditions[i] 248 if condition.when == nil { 249 // when condition is not set, just go into hook exec 250 match = append(match, func(args []reflect.Value) bool { return true }) 251 } else { 252 match = append(match, func(args []reflect.Value) bool { 253 return tool.ReflectCall(reflect.ValueOf(condition.when), args)[0].Bool() 254 }) 255 } 256 257 if condition.hook == nil { 258 // hook condition is not set, just go into original exec 259 exec = append(exec, originExec) 260 } else { 261 exec = append(exec, func(args []reflect.Value) []reflect.Value { 262 mocker.mock() 263 return tool.ReflectCall(reflect.ValueOf(condition.hook), args) 264 }) 265 } 266 } 267 268 mockerHook := reflect.MakeFunc(mocker.builder.hookType(), func(args []reflect.Value) []reflect.Value { 269 proxySetter(args) // 设置origin调用proxy 270 271 mocker.access() 272 switch mocker.builder.filterGoroutine { 273 case Disable: 274 break 275 case Include: 276 if tool.GetGoroutineID() != mocker.builder.gId { 277 return originExec(args) 278 } 279 case Exclude: 280 if tool.GetGoroutineID() == mocker.builder.gId { 281 return originExec(args) 282 } 283 } 284 285 for i, matchFn := range match { 286 execFn := exec[i] 287 if matchFn(args) { 288 return execFn(args) 289 } 290 } 291 292 return originExec(args) 293 }) 294 mocker.hook = mockerHook 295 } 296 297 // buildProx create a proxyCaller which could call origin directly 298 func (mocker *Mocker) buildProxy() func(args []reflect.Value) { 299 proxy := reflect.New(mocker.builder.hookType()) 300 301 proxyCallerSetter := func(args []reflect.Value) {} 302 if mocker.builder.proxyCaller != nil { 303 pVal := reflect.ValueOf(mocker.builder.proxyCaller) 304 tool.Assert(pVal.Kind() == reflect.Ptr && pVal.Elem().Kind() == reflect.Func, "origin receiver must be a function pointer") 305 pElem := pVal.Elem() 306 307 shift := 0 308 if mocker.builder.generic { 309 shift += 1 310 } 311 if mocker.missReceiver(mocker.target.Type(), pElem.Interface()) { 312 shift += 1 313 } 314 proxyCallerSetter = func(args []reflect.Value) { 315 pElem.Set(reflect.MakeFunc(pElem.Type(), func(innerArgs []reflect.Value) (results []reflect.Value) { 316 return tool.ReflectCall(proxy.Elem(), append(args[0:shift], innerArgs...)) 317 })) 318 } 319 } 320 mocker.proxy = proxy.Interface() 321 return proxyCallerSetter 322 } 323 324 func (mocker *Mocker) Patch() *Mocker { 325 mocker.lock.Lock() 326 defer mocker.lock.Unlock() 327 if mocker.isPatched { 328 return mocker 329 } 330 mocker.patch = monkey.PatchValue(mocker.target, mocker.hook, reflect.ValueOf(mocker.proxy), mocker.builder.unsafe, mocker.builder.generic) 331 mocker.isPatched = true 332 addToGlobal(mocker) 333 334 mocker.outerCaller = tool.OuterCaller() 335 return mocker 336 } 337 338 func (mocker *Mocker) UnPatch() *Mocker { 339 mocker.lock.Lock() 340 defer mocker.lock.Unlock() 341 if !mocker.isPatched { 342 return mocker 343 } 344 mocker.patch.Unpatch() 345 mocker.isPatched = false 346 removeFromGlobal(mocker) 347 atomic.StoreInt64(&mocker.times, 0) 348 atomic.StoreInt64(&mocker.mockTimes, 0) 349 350 return mocker 351 } 352 353 func (mocker *Mocker) Release() *MockBuilder { 354 mocker.UnPatch() 355 mocker.builder.resetCondition() 356 return mocker.builder 357 } 358 359 func (mocker *Mocker) ExcludeCurrentGoRoutine() *Mocker { 360 return mocker.rePatch(func() { 361 mocker.builder.ExcludeCurrentGoRoutine() 362 }) 363 } 364 365 func (mocker *Mocker) FilterGoRoutine(filter FilterGoroutineType, gId int64) *Mocker { 366 return mocker.rePatch(func() { 367 mocker.builder.FilterGoRoutine(filter, gId) 368 }) 369 } 370 371 func (mocker *Mocker) IncludeCurrentGoRoutine() *Mocker { 372 return mocker.rePatch(func() { 373 mocker.builder.IncludeCurrentGoRoutine() 374 }) 375 } 376 377 func (mocker *Mocker) When(when interface{}) *Mocker { 378 tool.Assert(len(mocker.builder.conditions) == 1, "only one-condition mocker could reset when (You can call Release first, then rebuild mocker)") 379 380 return mocker.rePatch(func() { 381 mocker.builder.conditions[0].SetWhenForce(when) 382 }) 383 } 384 385 func (mocker *Mocker) To(to interface{}) *Mocker { 386 tool.Assert(len(mocker.builder.conditions) == 1, "only one-condition mocker could reset to (You can call Release first, then rebuild mocker)") 387 388 return mocker.rePatch(func() { 389 mocker.builder.conditions[0].SetToForce(to) 390 }) 391 } 392 393 func (mocker *Mocker) Return(results ...interface{}) *Mocker { 394 tool.Assert(len(mocker.builder.conditions) == 1, "only one-condition mocker could reset return (You can call Release first, then rebuild mocker)") 395 396 return mocker.rePatch(func() { 397 mocker.builder.conditions[0].SetReturnForce(results...) 398 }) 399 } 400 401 func (mocker *Mocker) Origin(funcPtr interface{}) *Mocker { 402 return mocker.rePatch(func() { 403 mocker.builder.origin(funcPtr) 404 }) 405 } 406 407 func (mocker *Mocker) rePatch(do func()) *Mocker { 408 mocker.UnPatch() 409 do() 410 mocker.buildHook() 411 mocker.Patch() 412 return mocker 413 } 414 415 func (mocker *Mocker) access() { 416 atomic.AddInt64(&mocker.times, 1) 417 } 418 419 func (mocker *Mocker) mock() { 420 atomic.AddInt64(&mocker.mockTimes, 1) 421 } 422 423 func (mocker *Mocker) Times() int { 424 return int(atomic.LoadInt64(&mocker.times)) 425 } 426 427 func (mocker *Mocker) MockTimes() int { 428 return int(atomic.LoadInt64(&mocker.mockTimes)) 429 } 430 431 func (mocker *Mocker) key() uintptr { 432 return mocker.target.Pointer() 433 } 434 435 func (mocker *Mocker) name() string { 436 return mocker.target.String() 437 } 438 439 func (mocker *Mocker) unPatch() { 440 mocker.UnPatch() 441 } 442 443 func (mocker *Mocker) caller() tool.CallerInfo { 444 return mocker.outerCaller 445 }