github.com/gbl08ma/monkey@v1.1.0/monkey.go (about)

     1  package monkey // import "github.com/gbl08ma/monkey"
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"sync"
     7  	"unsafe"
     8  )
     9  
// stopTheWorld is bound to the runtime's unexported runtime.stopTheWorld via
// go:linkname. This package calls it right before rewriting function code and
// always pairs it with a deferred startTheWorld; panicking between the two is
// fatal (see the comment in patchValue).
//
// NOTE(review): the runtime's signatures for stopTheWorld/startTheWorld have
// changed across Go releases, and newer toolchains restrict linkname access to
// runtime internals — confirm these declarations match the toolchain in use.
//
//go:linkname stopTheWorld runtime.stopTheWorld
func stopTheWorld(reason string)

// startTheWorld is bound to the runtime's unexported runtime.startTheWorld via
// go:linkname and undoes a preceding stopTheWorld call.
//
//go:linkname startTheWorld runtime.startTheWorld
func startTheWorld()
    15  
// patch records one applied patch: everything needed to undo it later.
type patch struct {
	// originalBytes is what replaceFunction returned when the patch was
	// applied — presumably the target's original code bytes; unpatch writes
	// them back over the target with copyToLocation.
	originalBytes []byte

	// replacement is the function value the target was redirected to. It is
	// not read in this file; presumably it is stored so the replacement stays
	// reachable (and is not collected) while the patch is active.
	replacement *reflect.Value
}
    22  
    23  var (
    24  	lock = sync.Mutex{}
    25  
    26  	patches = make(map[uintptr]patch)
    27  )
    28  
    29  type value struct {
    30  	_   uintptr
    31  	ptr unsafe.Pointer
    32  }
    33  
    34  func getPtr(v reflect.Value) unsafe.Pointer {
    35  	return (*value)(unsafe.Pointer(&v)).ptr
    36  }
    37  
// PatchGuard is returned by Patch and PatchInstanceMethod. It remembers the
// target and replacement of one patch so that the patch can later be removed
// (Unpatch, UnpatchLock) or re-applied (Restore, RestoreLock).
type PatchGuard struct {
	// lock is acquired by UnpatchLock and released by RestoreLock.
	lock sync.Mutex

	// target is the function (or method) value whose code is patched.
	target reflect.Value

	// replacement is the function value the target is redirected to.
	replacement reflect.Value
}
    43  
    44  // Unpatch removes the patch
    45  func (g *PatchGuard) Unpatch() {
    46  	unpatchValue(g.target)
    47  }
    48  
    49  // Restore restores the replacement function
    50  func (g *PatchGuard) Restore() {
    51  	patchValue(g.target, g.replacement)
    52  }
    53  
    54  // UnpatchLock removes the patch in a thread-safe way, expecting the patch to be restored later using RestoreLock
    55  func (g *PatchGuard) UnpatchLock() {
    56  	g.lock.Lock()
    57  	unpatchValue(g.target)
    58  }
    59  
    60  // RestoreLock restores the replacement function in a thread-safe way. Must only be called if UnpatchLock was used previously
    61  func (g *PatchGuard) RestoreLock() {
    62  	patchValue(g.target, g.replacement)
    63  	g.lock.Unlock()
    64  }
    65  
    66  // Patch replaces a function with another
    67  func Patch(target, replacement interface{}) *PatchGuard {
    68  	t := reflect.ValueOf(target)
    69  	r := reflect.ValueOf(replacement)
    70  	patchValue(t, r)
    71  
    72  	return &PatchGuard{sync.Mutex{}, t, r}
    73  }
    74  
    75  // PatchInstanceMethod replaces an instance method methodName for the type target with replacement
    76  // Replacement should expect the receiver (of type target) as the first argument
    77  func PatchInstanceMethod(target reflect.Type, methodName string, replacement interface{}) *PatchGuard {
    78  	m, ok := target.MethodByName(methodName)
    79  	if !ok {
    80  		panic(fmt.Sprintf("unknown method %s", methodName))
    81  	}
    82  	r := reflect.ValueOf(replacement)
    83  	patchValue(m.Func, r)
    84  
    85  	return &PatchGuard{sync.Mutex{}, m.Func, r}
    86  }
    87  
    88  // PatchCount returns the number of currently patched functions and type methods
    89  func PatchCount() int {
    90  	lock.Lock()
    91  	defer lock.Unlock()
    92  	return len(patches)
    93  }
    94  
    95  func patchValue(target, replacement reflect.Value) {
    96  	lock.Lock()
    97  	defer lock.Unlock()
    98  
    99  	if target.Kind() != reflect.Func {
   100  		panic("target has to be a Func")
   101  	}
   102  
   103  	if replacement.Kind() != reflect.Func {
   104  		panic("replacement has to be a Func")
   105  	}
   106  
   107  	if target.Type() != replacement.Type() {
   108  		panic(fmt.Sprintf("target and replacement have to have the same type %s != %s", target.Type(), replacement.Type()))
   109  	}
   110  
   111  	// once the world is stopped, we cannot panic, or we will be forced to crash
   112  	stopTheWorld("monkey patch")
   113  	defer startTheWorld()
   114  
   115  	if patch, ok := patches[target.Pointer()]; ok {
   116  		unpatch(target.Pointer(), patch)
   117  	}
   118  
   119  	bytes := replaceFunction(target.Pointer(), (uintptr)(getPtr(replacement)))
   120  	patches[target.Pointer()] = patch{bytes, &replacement}
   121  }
   122  
   123  // Unpatch removes any monkey patches on target
   124  // returns whether target was patched in the first place
   125  func Unpatch(target interface{}) bool {
   126  	return unpatchValue(reflect.ValueOf(target))
   127  }
   128  
   129  // UnpatchInstanceMethod removes the patch on methodName of the target
   130  // returns whether it was patched in the first place
   131  func UnpatchInstanceMethod(target reflect.Type, methodName string) bool {
   132  	m, ok := target.MethodByName(methodName)
   133  	if !ok {
   134  		panic(fmt.Sprintf("unknown method %s", methodName))
   135  	}
   136  	return unpatchValue(m.Func)
   137  }
   138  
   139  // UnpatchAll removes all applied monkeypatches
   140  func UnpatchAll() {
   141  	lock.Lock()
   142  	defer lock.Unlock()
   143  	stopTheWorld("monkey unpatch all")
   144  	defer startTheWorld()
   145  	for target, p := range patches {
   146  		unpatch(target, p)
   147  		delete(patches, target)
   148  	}
   149  }
   150  
   151  // Unpatch removes a monkeypatch from the specified function
   152  // returns whether the function was patched in the first place
   153  func unpatchValue(target reflect.Value) bool {
   154  	lock.Lock()
   155  	defer lock.Unlock()
   156  	stopTheWorld("monkey unpatch")
   157  	defer startTheWorld()
   158  	patch, ok := patches[target.Pointer()]
   159  	if !ok {
   160  		return false
   161  	}
   162  	unpatch(target.Pointer(), patch)
   163  	delete(patches, target.Pointer())
   164  	return true
   165  }
   166  
   167  func unpatch(target uintptr, p patch) {
   168  	copyToLocation(target, p.originalBytes)
   169  }