github.com/bytedance/mockey@v1.2.10/mock.go (about)

     1  /*
     2   * Copyright 2022 ByteDance Inc.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package mockey
    18  
    19  import (
    20  	"reflect"
    21  	"sync"
    22  	"sync/atomic"
    23  
    24  	"github.com/bytedance/mockey/internal/monkey"
    25  	"github.com/bytedance/mockey/internal/tool"
    26  )
    27  
    28  type FilterGoroutineType int64
    29  
    30  const (
    31  	Disable FilterGoroutineType = 0
    32  	Include FilterGoroutineType = 1
    33  	Exclude FilterGoroutineType = 2
    34  )
    35  
    36  type Mocker struct {
    37  	target    reflect.Value // mock target value
    38  	hook      reflect.Value // mock hook
    39  	proxy     interface{}   // proxy function to origin
    40  	times     int64
    41  	mockTimes int64
    42  	patch     *monkey.Patch
    43  	lock      sync.Mutex
    44  	isPatched bool
    45  	builder   *MockBuilder
    46  
    47  	outerCaller tool.CallerInfo
    48  }
    49  
    50  type MockBuilder struct {
    51  	target          interface{}      // mock target
    52  	proxyCaller     interface{}      // origin function caller hook
    53  	conditions      []*mockCondition // mock conditions
    54  	filterGoroutine FilterGoroutineType
    55  	gId             int64
    56  	unsafe          bool
    57  	generic         bool
    58  }
    59  
    60  // Mock mocks target function
    61  //
    62  // If target is a generic method or method of generic types, you need add a genericOpt, like this:
    63  //
    64  //	func f[int, float64](x int, y T1) T2
    65  //	Mock(f[int, float64], OptGeneric)
    66  func Mock(target interface{}, opt ...optionFn) *MockBuilder {
    67  	tool.AssertFunc(target)
    68  
    69  	option := resolveOpt(opt...)
    70  
    71  	builder := &MockBuilder{
    72  		target:  target,
    73  		unsafe:  option.unsafe,
    74  		generic: option.generic,
    75  	}
    76  	builder.resetCondition()
    77  	return builder
    78  }
    79  
    80  // MockUnsafe has the full ability of the Mock function and removes some security restrictions. This is an alternative
    81  // when the Mock function fails. It may cause some unknown problems, so we recommend using Mock under normal conditions.
    82  func MockUnsafe(target interface{}) *MockBuilder {
    83  	return Mock(target, OptUnsafe)
    84  }
    85  
    86  func (builder *MockBuilder) hookType() reflect.Type {
    87  	targetType := reflect.TypeOf(builder.target)
    88  	if builder.generic {
    89  		targetIn := []reflect.Type{genericInfoType}
    90  		for i := 0; i < targetType.NumIn(); i++ {
    91  			targetIn = append(targetIn, targetType.In(i))
    92  		}
    93  		targetOut := []reflect.Type{}
    94  		for i := 0; i < targetType.NumOut(); i++ {
    95  			targetOut = append(targetOut, targetType.Out(i))
    96  		}
    97  		return reflect.FuncOf(targetIn, targetOut, targetType.IsVariadic())
    98  	}
    99  	return targetType
   100  }
   101  
   102  func (builder *MockBuilder) resetCondition() *MockBuilder {
   103  	builder.conditions = []*mockCondition{builder.newCondition()} // at least 1 condition is needed
   104  	return builder
   105  }
   106  
   107  // Origin add an origin hook which can be used to call un-mocked origin function
   108  //
   109  // For example:
   110  //
   111  //	 origin := Fun // only need the same type
   112  //	 mock := func(p string) string {
   113  //		 return origin(p + "mocked")
   114  //	 }
   115  //	 mock2 := Mock(Fun).To(mock).Origin(&origin).Build()
   116  //
   117  // Origin only works when call origin hook directly, target will still be mocked in recursive call
   118  func (builder *MockBuilder) Origin(funcPtr interface{}) *MockBuilder {
   119  	tool.Assert(builder.proxyCaller == nil, "re-set builder origin")
   120  	return builder.origin(funcPtr)
   121  }
   122  
   123  func (builder *MockBuilder) origin(funcPtr interface{}) *MockBuilder {
   124  	tool.AssertPtr(funcPtr)
   125  	builder.proxyCaller = funcPtr
   126  	return builder
   127  }
   128  
   129  func (builder *MockBuilder) lastCondition() *mockCondition {
   130  	cond := builder.conditions[len(builder.conditions)-1]
   131  	if cond.Complete() {
   132  		cond = builder.newCondition()
   133  		builder.conditions = append(builder.conditions, cond)
   134  	}
   135  	return cond
   136  }
   137  
   138  func (builder *MockBuilder) newCondition() *mockCondition {
   139  	return &mockCondition{builder: builder}
   140  }
   141  
   142  // When declares the condition hook that's called to determine whether the mock should be executed.
   143  //
   144  // The condition hook function must have the same parameters as the target function.
   145  //
   146  // The following example would execute the mock when input int is negative
   147  //
   148  //	func Fun(input int) string {
   149  //		return strconv.Itoa(input)
   150  //	}
   151  //	Mock(Fun).When(func(input int) bool { return input < 0 }).Return("0").Build()
   152  //
   153  // Note that if the target function is a struct method, you may optionally include
   154  // the receiver as the first argument of the condition hook function. For example,
   155  //
   156  //	type Foo struct {
   157  //		Age int
   158  //	}
   159  //	func (f *Foo) GetAge(younger int) string {
   160  //		return strconv.Itoa(f.Age - younger)
   161  //	}
   162  //	Mock((*Foo).GetAge).When(func(f *Foo, younger int) bool { return younger < 0 }).Return("0").Build()
   163  func (builder *MockBuilder) When(when interface{}) *MockBuilder {
   164  	builder.lastCondition().SetWhen(when)
   165  	return builder
   166  }
   167  
   168  // To declares the hook function that's called to replace the target function.
   169  //
   170  // The hook function must have the same signature as the target function.
   171  //
   172  // The following example would make Fun always return true
   173  //
   174  //	func Fun(input string) bool {
   175  //		return input == "fun"
   176  //	}
   177  //
   178  //	Mock(Fun).To(func(_ string) bool {return true}).Build()
   179  //
   180  // Note that if the target function is a struct method, you may optionally include
   181  // the receiver as the first argument of the hook function. For example,
   182  //
   183  //	type Foo struct {
   184  //		Name string
   185  //	}
   186  //	func (f *Foo) Bar(other string) bool {
   187  //		return other == f.Name
   188  //	}
   189  //	Mock((*Foo).Bar).To(func(f *Foo, other string) bool {return true}).Build()
   190  func (builder *MockBuilder) To(hook interface{}) *MockBuilder {
   191  	builder.lastCondition().SetTo(hook)
   192  	return builder
   193  }
   194  
   195  func (builder *MockBuilder) Return(results ...interface{}) *MockBuilder {
   196  	builder.lastCondition().SetReturn(results...)
   197  	return builder
   198  }
   199  
   200  func (builder *MockBuilder) IncludeCurrentGoRoutine() *MockBuilder {
   201  	return builder.FilterGoRoutine(Include, tool.GetGoroutineID())
   202  }
   203  
   204  func (builder *MockBuilder) ExcludeCurrentGoRoutine() *MockBuilder {
   205  	return builder.FilterGoRoutine(Exclude, tool.GetGoroutineID())
   206  }
   207  
   208  func (builder *MockBuilder) FilterGoRoutine(filter FilterGoroutineType, gId int64) *MockBuilder {
   209  	builder.filterGoroutine = filter
   210  	builder.gId = gId
   211  	return builder
   212  }
   213  
   214  func (builder *MockBuilder) Build() *Mocker {
   215  	mocker := Mocker{target: reflect.ValueOf(builder.target), builder: builder}
   216  	mocker.buildHook()
   217  	mocker.Patch()
   218  	return &mocker
   219  }
   220  
   221  func (mocker *Mocker) missReceiver(target reflect.Type, hook interface{}) bool {
   222  	hType := reflect.TypeOf(hook)
   223  	tool.Assert(hType.Kind() == reflect.Func, "Param(%v) a is not a func", hType.Kind())
   224  	tool.Assert(target.IsVariadic() == hType.IsVariadic(), "target:%v, hook:%v args not match", target, hook)
   225  	// has receiver
   226  	if tool.CheckFuncArgs(target, hType, 0, 0) {
   227  		return false
   228  	}
   229  	if tool.CheckFuncArgs(target, hType, 1, 0) {
   230  		return true
   231  	}
   232  	tool.Assert(false, "target:%v, hook:%v args not match", target, hook)
   233  	return false
   234  }
   235  
   236  func (mocker *Mocker) buildHook() {
   237  	proxySetter := mocker.buildProxy()
   238  
   239  	originExec := func(args []reflect.Value) []reflect.Value {
   240  		return tool.ReflectCall(reflect.ValueOf(mocker.proxy).Elem(), args)
   241  	}
   242  
   243  	match := []func(args []reflect.Value) bool{}
   244  	exec := []func(args []reflect.Value) []reflect.Value{}
   245  
   246  	for i := range mocker.builder.conditions {
   247  		condition := mocker.builder.conditions[i]
   248  		if condition.when == nil {
   249  			// when condition is not set, just go into hook exec
   250  			match = append(match, func(args []reflect.Value) bool { return true })
   251  		} else {
   252  			match = append(match, func(args []reflect.Value) bool {
   253  				return tool.ReflectCall(reflect.ValueOf(condition.when), args)[0].Bool()
   254  			})
   255  		}
   256  
   257  		if condition.hook == nil {
   258  			// hook condition is not set, just go into original exec
   259  			exec = append(exec, originExec)
   260  		} else {
   261  			exec = append(exec, func(args []reflect.Value) []reflect.Value {
   262  				mocker.mock()
   263  				return tool.ReflectCall(reflect.ValueOf(condition.hook), args)
   264  			})
   265  		}
   266  	}
   267  
   268  	mockerHook := reflect.MakeFunc(mocker.builder.hookType(), func(args []reflect.Value) []reflect.Value {
   269  		proxySetter(args) // 设置origin调用proxy
   270  
   271  		mocker.access()
   272  		switch mocker.builder.filterGoroutine {
   273  		case Disable:
   274  			break
   275  		case Include:
   276  			if tool.GetGoroutineID() != mocker.builder.gId {
   277  				return originExec(args)
   278  			}
   279  		case Exclude:
   280  			if tool.GetGoroutineID() == mocker.builder.gId {
   281  				return originExec(args)
   282  			}
   283  		}
   284  
   285  		for i, matchFn := range match {
   286  			execFn := exec[i]
   287  			if matchFn(args) {
   288  				return execFn(args)
   289  			}
   290  		}
   291  
   292  		return originExec(args)
   293  	})
   294  	mocker.hook = mockerHook
   295  }
   296  
   297  // buildProx create a proxyCaller which could call origin directly
   298  func (mocker *Mocker) buildProxy() func(args []reflect.Value) {
   299  	proxy := reflect.New(mocker.builder.hookType())
   300  
   301  	proxyCallerSetter := func(args []reflect.Value) {}
   302  	if mocker.builder.proxyCaller != nil {
   303  		pVal := reflect.ValueOf(mocker.builder.proxyCaller)
   304  		tool.Assert(pVal.Kind() == reflect.Ptr && pVal.Elem().Kind() == reflect.Func, "origin receiver must be a function pointer")
   305  		pElem := pVal.Elem()
   306  
   307  		shift := 0
   308  		if mocker.builder.generic {
   309  			shift += 1
   310  		}
   311  		if mocker.missReceiver(mocker.target.Type(), pElem.Interface()) {
   312  			shift += 1
   313  		}
   314  		proxyCallerSetter = func(args []reflect.Value) {
   315  			pElem.Set(reflect.MakeFunc(pElem.Type(), func(innerArgs []reflect.Value) (results []reflect.Value) {
   316  				return tool.ReflectCall(proxy.Elem(), append(args[0:shift], innerArgs...))
   317  			}))
   318  		}
   319  	}
   320  	mocker.proxy = proxy.Interface()
   321  	return proxyCallerSetter
   322  }
   323  
   324  func (mocker *Mocker) Patch() *Mocker {
   325  	mocker.lock.Lock()
   326  	defer mocker.lock.Unlock()
   327  	if mocker.isPatched {
   328  		return mocker
   329  	}
   330  	mocker.patch = monkey.PatchValue(mocker.target, mocker.hook, reflect.ValueOf(mocker.proxy), mocker.builder.unsafe, mocker.builder.generic)
   331  	mocker.isPatched = true
   332  	addToGlobal(mocker)
   333  
   334  	mocker.outerCaller = tool.OuterCaller()
   335  	return mocker
   336  }
   337  
   338  func (mocker *Mocker) UnPatch() *Mocker {
   339  	mocker.lock.Lock()
   340  	defer mocker.lock.Unlock()
   341  	if !mocker.isPatched {
   342  		return mocker
   343  	}
   344  	mocker.patch.Unpatch()
   345  	mocker.isPatched = false
   346  	removeFromGlobal(mocker)
   347  	atomic.StoreInt64(&mocker.times, 0)
   348  	atomic.StoreInt64(&mocker.mockTimes, 0)
   349  
   350  	return mocker
   351  }
   352  
   353  func (mocker *Mocker) Release() *MockBuilder {
   354  	mocker.UnPatch()
   355  	mocker.builder.resetCondition()
   356  	return mocker.builder
   357  }
   358  
   359  func (mocker *Mocker) ExcludeCurrentGoRoutine() *Mocker {
   360  	return mocker.rePatch(func() {
   361  		mocker.builder.ExcludeCurrentGoRoutine()
   362  	})
   363  }
   364  
   365  func (mocker *Mocker) FilterGoRoutine(filter FilterGoroutineType, gId int64) *Mocker {
   366  	return mocker.rePatch(func() {
   367  		mocker.builder.FilterGoRoutine(filter, gId)
   368  	})
   369  }
   370  
   371  func (mocker *Mocker) IncludeCurrentGoRoutine() *Mocker {
   372  	return mocker.rePatch(func() {
   373  		mocker.builder.IncludeCurrentGoRoutine()
   374  	})
   375  }
   376  
   377  func (mocker *Mocker) When(when interface{}) *Mocker {
   378  	tool.Assert(len(mocker.builder.conditions) == 1, "only one-condition mocker could reset when (You can call Release first, then rebuild mocker)")
   379  
   380  	return mocker.rePatch(func() {
   381  		mocker.builder.conditions[0].SetWhenForce(when)
   382  	})
   383  }
   384  
   385  func (mocker *Mocker) To(to interface{}) *Mocker {
   386  	tool.Assert(len(mocker.builder.conditions) == 1, "only one-condition mocker could reset to  (You can call Release first, then rebuild mocker)")
   387  
   388  	return mocker.rePatch(func() {
   389  		mocker.builder.conditions[0].SetToForce(to)
   390  	})
   391  }
   392  
   393  func (mocker *Mocker) Return(results ...interface{}) *Mocker {
   394  	tool.Assert(len(mocker.builder.conditions) == 1, "only one-condition mocker could reset return  (You can call Release first, then rebuild mocker)")
   395  
   396  	return mocker.rePatch(func() {
   397  		mocker.builder.conditions[0].SetReturnForce(results...)
   398  	})
   399  }
   400  
   401  func (mocker *Mocker) Origin(funcPtr interface{}) *Mocker {
   402  	return mocker.rePatch(func() {
   403  		mocker.builder.origin(funcPtr)
   404  	})
   405  }
   406  
   407  func (mocker *Mocker) rePatch(do func()) *Mocker {
   408  	mocker.UnPatch()
   409  	do()
   410  	mocker.buildHook()
   411  	mocker.Patch()
   412  	return mocker
   413  }
   414  
   415  func (mocker *Mocker) access() {
   416  	atomic.AddInt64(&mocker.times, 1)
   417  }
   418  
   419  func (mocker *Mocker) mock() {
   420  	atomic.AddInt64(&mocker.mockTimes, 1)
   421  }
   422  
   423  func (mocker *Mocker) Times() int {
   424  	return int(atomic.LoadInt64(&mocker.times))
   425  }
   426  
   427  func (mocker *Mocker) MockTimes() int {
   428  	return int(atomic.LoadInt64(&mocker.mockTimes))
   429  }
   430  
   431  func (mocker *Mocker) key() uintptr {
   432  	return mocker.target.Pointer()
   433  }
   434  
   435  func (mocker *Mocker) name() string {
   436  	return mocker.target.String()
   437  }
   438  
   439  func (mocker *Mocker) unPatch() {
   440  	mocker.UnPatch()
   441  }
   442  
   443  func (mocker *Mocker) caller() tool.CallerInfo {
   444  	return mocker.outerCaller
   445  }